From 26d4de16ca03877feb256c99ce3c4b03a9bfd23a Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Tue, 24 Mar 2026 00:32:01 -0700 Subject: [PATCH 01/12] [None][perf] Split MLA DSA custom op for piecewise CUDA graph capture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split the monolithic mla_custom_op_inplace into two ops for DSA models: - mla_dsa_proj (Op 1): Token-wise projections (cublas_mm, rope, FP8 quantize, weight scaling). CUDA-graph-capturable — no batch metadata access, no tensor slicing by num_tokens. - mla_dsa_attn_inplace (Op 2): Batch-dependent k cache update, sparse_attn_indexer, and context/generation attention dispatch. Excluded from CUDA graph capture. This enables the piecewise CUDA graph optimizer to capture the compute-heavy projection portion of DSA MLA while keeping the batch-structure-dependent attention dispatch outside the graph. Key design decisions: - Indexer split into pre_indexer_proj (graph-safe) and _update_k_cache (moved to Op 2) to avoid capturing metadata-dependent scatter ops. - All num_tokens slicing deferred to Op 2 so graph capture sees fixed-shape tensors. - Indexer intermediates (q_fp8, k_fp8, k_scale, weights) returned from Op 1 as List[Tensor] and passed explicitly to Op 2 — no stashing on self to avoid CUDA graph address aliasing. - _should_use_short_mha disabled under torch compile for straight-line control flow in Op 1. - Non-DSA MLA unchanged (still uses mla_custom_op_inplace). Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- .../_torch/attention_backend/sparse/dsa.py | 31 ++++++- tensorrt_llm/_torch/modules/attention.py | 86 ++++++++++++++++++- 2 files changed, 113 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 6d237cf01fe..e483651ae37 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1801,12 +1801,37 @@ def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor): qk_pe, qk_nope, self.scale_fmt == "ue8m0") return fp8_out, scale + @torch.inference_mode() + def pre_indexer( + self, qr: torch.Tensor, hidden_states: torch.Tensor, + metadata: DSAtrtllmAttentionMetadata, position_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Token-wise projections, FP8 quantize, weight scaling, and k cache update. + + Runs the full indexer pre-computation including k cache update. + Used by the eager path (Indexer.forward) where everything runs + outside CUDA graph capture. + + Returns (q_fp8, k_fp8, k_scale, weights). + """ + q_fp8, k_fp8, k_scale, weights = self.pre_indexer_proj( + qr, hidden_states, position_ids) + + weights, _ = maybe_execute_in_parallel( + lambda: weights, + lambda: self._update_k_cache(k_fp8, k_scale, metadata), + self.ln_events[0], + self.ln_events[1], + self.aux_stream, + ) + + return q_fp8, k_fp8, k_scale, weights + def pre_indexer_proj( self, qr: torch.Tensor, hidden_states: torch.Tensor, position_ids: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Pure token-wise projections (CUDA-graph-capturable). - Runs cublas_mm, qk_projection_and_rope, FP8 quantize, and weight scaling. Does NOT touch the k cache or any batch-specific metadata, so this can safely run inside a captured CUDA graph partition. @@ -1847,8 +1872,8 @@ def pre_indexer_proj( def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor, metadata: DSAtrtllmAttentionMetadata, position_ids: torch.Tensor): - q_fp8, k_fp8, k_scale, weights = self.pre_indexer_proj( - qr, hidden_states, position_ids) + q_fp8, k_fp8, k_scale, weights = self.pre_indexer( + qr, hidden_states, metadata, position_ids) # Return topk indices buffer for sparse attention [num_tokens, index_topk] return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8, diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index fd0714e25a2..ae20286932a 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1050,6 +1050,84 @@ def mla_dsa_attn_inplace( output) +@torch.library.custom_op("trtllm::mla_dsa_proj", mutates_args=()) +def mla_dsa_proj( + hidden_states: torch.Tensor, + position_ids: Optional[torch.Tensor], + layer_idx: str, +) -> List[torch.Tensor]: + """Token-wise projections for DSA MLA (CUDA-graph-capturable). + + Runs kv_a_proj, layernorms, q_b_proj, and conditionally + indexer.pre_indexer (which updates the indexer k cache). + + Returns [q, compressed_kv, k_pe, latent_cache] when the short-MHA path + handles all tokens, or [q, compressed_kv, k_pe, latent_cache, q_fp8, + k_fp8, k_scale, weights] when the indexer runs. Under torch compile, + _should_use_short_mha returns False so the result is always length 8, + keeping control flow straight-line for CUDA graph capture. + """ + metadata, mla_layer = extract_extra_attrs(layer_idx, "mla") + return mla_layer.forward_dsa_proj(position_ids, hidden_states, metadata) + + +@mla_dsa_proj.register_fake +def _mla_dsa_proj_fake( + hidden_states: torch.Tensor, + position_ids: Optional[torch.Tensor], + layer_idx: str, +) -> List[torch.Tensor]: + # Under torch compile _should_use_short_mha is False, so always 8 tensors. + metadata, mla_layer = extract_extra_attrs(layer_idx, "mla") + num_tokens = hidden_states.shape[0] + indexer = mla_layer.mqa.indexer + q = hidden_states.new_empty( + [num_tokens, mla_layer.num_heads_tp * mla_layer.qk_head_dim]) + compressed_kv = hidden_states.new_empty( + [num_tokens, mla_layer.kv_lora_rank]) + k_pe = hidden_states.new_empty([num_tokens, mla_layer.qk_rope_head_dim]) + latent_cache = hidden_states.new_empty( + [num_tokens, mla_layer.kv_lora_rank + mla_layer.qk_rope_head_dim]) + # Indexer intermediates: q_fp8, k_fp8, k_scale, weights + q_fp8 = hidden_states.new_empty( + [num_tokens, indexer.n_heads, indexer.head_dim], + dtype=torch.float8_e4m3fn) + k_fp8 = hidden_states.new_empty([num_tokens, indexer.head_dim], + dtype=torch.float8_e4m3fn) + k_scale = hidden_states.new_empty([num_tokens, indexer.head_dim // 128], + dtype=torch.float32) + weights = hidden_states.new_empty([num_tokens, indexer.n_heads], + dtype=torch.float32) + return [ + q, compressed_kv, k_pe, latent_cache, q_fp8, k_fp8, k_scale, weights + ] + + +@torch.library.custom_op("trtllm::mla_dsa_attn_inplace", + mutates_args=("output", )) +def mla_dsa_attn_inplace( + q: torch.Tensor, + compressed_kv: torch.Tensor, + k_pe: torch.Tensor, + latent_cache: torch.Tensor, + indexer_intermediates: List[torch.Tensor], + position_ids: Optional[torch.Tensor], + layer_idx: str, + output: torch.Tensor, +) -> None: + """Batch-structure-dependent attention dispatch for DSA MLA. + + indexer_intermediates is [q_fp8, k_fp8, k_scale, weights] when the + indexer ran in Op 1, or [] when short-MHA handled all tokens. + Runs sparse_attn_indexer then dispatches context/generation attention. + This op is excluded from CUDA graph capture. + """ + metadata, mla_layer = extract_extra_attrs(layer_idx, "mla") + mla_layer.forward_dsa_attn(q, compressed_kv, k_pe, latent_cache, + indexer_intermediates, position_ids, metadata, + output) + + def fp8_block_scaling_bmm_out( mat1: torch.Tensor, mat2_fp8: torch.Tensor, @@ -1806,6 +1884,9 @@ def forward_dsa_attn( k_fp8 = k_fp8[:num_tokens, ...] k_scale = k_scale[:num_tokens, ...] weights = weights[:num_tokens, ...] + # Update the indexer k cache here (outside CUDA graph) because + # it accesses batch-specific metadata (slot_mapping_fp8/scale). + self.mqa.indexer._update_k_cache(k_fp8, k_scale, attn_metadata) topk_indices = self.mqa.indexer.sparse_attn_indexer( attn_metadata, q, # only used for shape/device in buffer allocation @@ -2781,7 +2862,10 @@ def forward( if self.is_dsa: proj_outputs = torch.ops.trtllm.mla_dsa_proj( hidden_states, position_ids, self.layer_idx_str) - q, compressed_kv, k_pe, latent_cache = proj_outputs[:4] + q, compressed_kv, k_pe, latent_cache = (proj_outputs[0], + proj_outputs[1], + proj_outputs[2], + proj_outputs[3]) indexer_intermediates = proj_outputs[4:] torch.ops.trtllm.mla_dsa_attn_inplace( q, compressed_kv, k_pe, latent_cache, indexer_intermediates, From d2ba3f9050c66efaaf09823e00921f74931d841f Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Tue, 24 Mar 2026 07:30:52 -0700 Subject: [PATCH 02/12] [None][fix] Fix pre_indexer bogus parallel pattern and mla_dsa_proj docstring - Remove no-op `lambda: weights` in pre_indexer's maybe_execute_in_parallel; _weight_scale already ran in pre_indexer_proj, so just call _update_k_cache directly. - Fix mla_dsa_proj docstring: k cache update happens in Op 2 (mla_dsa_attn_inplace), not Op 1. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/sparse/dsa.py | 10 +--------- tensorrt_llm/_torch/modules/attention.py | 4 +++- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index e483651ae37..3553587d268 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1816,15 +1816,7 @@ def pre_indexer( """ q_fp8, k_fp8, k_scale, weights = self.pre_indexer_proj( qr, hidden_states, position_ids) - - weights, _ = maybe_execute_in_parallel( - lambda: weights, - lambda: self._update_k_cache(k_fp8, k_scale, metadata), - self.ln_events[0], - self.ln_events[1], - self.aux_stream, - ) - + self._update_k_cache(k_fp8, k_scale, metadata) return q_fp8, k_fp8, k_scale, weights def pre_indexer_proj( diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index ae20286932a..8fd30cf30e9 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1059,7 +1059,9 @@ def mla_dsa_proj( """Token-wise projections for DSA MLA (CUDA-graph-capturable). Runs kv_a_proj, layernorms, q_b_proj, and conditionally - indexer.pre_indexer (which updates the indexer k cache). + indexer.pre_indexer_proj (FP8 quantize, weight scaling). Does NOT + update the indexer k cache — that happens in Op 2 (mla_dsa_attn_inplace) + because the scatter kernel accesses batch-specific metadata. Returns [q, compressed_kv, k_pe, latent_cache] when the short-MHA path handles all tokens, or [q, compressed_kv, k_pe, latent_cache, q_fp8, From 446a66a75b1a447e17a32205309f2070836ccce4 Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Tue, 24 Mar 2026 07:41:34 -0700 Subject: [PATCH 03/12] [None][refactor] Move _update_k_cache into sparse_attn_indexer Move _update_k_cache call to the top of sparse_attn_indexer so the k cache is populated right before prefill chunks gather from it. Remove pre_indexer (now redundant); forward() and forward_dsa_proj both call pre_indexer_proj directly. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- .../_torch/attention_backend/sparse/dsa.py | 26 ++----------------- tensorrt_llm/_torch/modules/attention.py | 3 --- 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 3553587d268..ada5f5ed66d 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1499,10 +1499,6 @@ def sparse_attn_indexer( weights: torch.Tensor, use_custom_topk: bool = True, ) -> torch.Tensor: - """Run the indexer TopK kernel for both prefill and decode phases.""" - assert metadata.kv_cache_manager is None or \ - metadata.kv_cache_manager.quant_block_size == 128, \ - "Only support quant_block_size = 128 for now" # Update the indexer k cache before prefill chunks gather from it. self._update_k_cache(k_fp8, k_scale, metadata) @@ -1801,24 +1797,6 @@ def _prep_q_or_k(self, qk_pe: torch.Tensor, qk_nope: torch.Tensor): qk_pe, qk_nope, self.scale_fmt == "ue8m0") return fp8_out, scale - @torch.inference_mode() - def pre_indexer( - self, qr: torch.Tensor, hidden_states: torch.Tensor, - metadata: DSAtrtllmAttentionMetadata, position_ids: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Token-wise projections, FP8 quantize, weight scaling, and k cache update. - - Runs the full indexer pre-computation including k cache update. - Used by the eager path (Indexer.forward) where everything runs - outside CUDA graph capture. - - Returns (q_fp8, k_fp8, k_scale, weights). - """ - q_fp8, k_fp8, k_scale, weights = self.pre_indexer_proj( - qr, hidden_states, position_ids) - self._update_k_cache(k_fp8, k_scale, metadata) - return q_fp8, k_fp8, k_scale, weights - def pre_indexer_proj( self, qr: torch.Tensor, hidden_states: torch.Tensor, position_ids: torch.Tensor @@ -1864,8 +1842,8 @@ def pre_indexer_proj( def forward(self, qr: torch.Tensor, hidden_states: torch.Tensor, metadata: DSAtrtllmAttentionMetadata, position_ids: torch.Tensor): - q_fp8, k_fp8, k_scale, weights = self.pre_indexer( - qr, hidden_states, metadata, position_ids) + q_fp8, k_fp8, k_scale, weights = self.pre_indexer_proj( + qr, hidden_states, position_ids) # Return topk indices buffer for sparse attention [num_tokens, index_topk] return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8, diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 8fd30cf30e9..3f1f1e10062 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1886,9 +1886,6 @@ def forward_dsa_attn( k_fp8 = k_fp8[:num_tokens, ...] k_scale = k_scale[:num_tokens, ...] weights = weights[:num_tokens, ...] - # Update the indexer k cache here (outside CUDA graph) because - # it accesses batch-specific metadata (slot_mapping_fp8/scale). - self.mqa.indexer._update_k_cache(k_fp8, k_scale, attn_metadata) topk_indices = self.mqa.indexer.sparse_attn_indexer( attn_metadata, q, # only used for shape/device in buffer allocation From 358154135075cb1e9f03e1233998b6f8e74c76ef Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:36:58 -0700 Subject: [PATCH 04/12] [None][chore] Clean up MLA DSA custom op dispatch - Remove dead is_dsa branch from mla_custom_op_inplace since DSA is now exclusively handled by the split mla_dsa_proj/mla_dsa_attn_inplace ops - Use literal 1 for k_scale shape to match C++ fusedCatFp8 kernel output - Simplify proj_outputs unpacking Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- tensorrt_llm/_torch/modules/attention.py | 85 +----------------------- 1 file changed, 1 insertion(+), 84 deletions(-) diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 3f1f1e10062..fd0714e25a2 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1050,86 +1050,6 @@ def mla_dsa_attn_inplace( output) -@torch.library.custom_op("trtllm::mla_dsa_proj", mutates_args=()) -def mla_dsa_proj( - hidden_states: torch.Tensor, - position_ids: Optional[torch.Tensor], - layer_idx: str, -) -> List[torch.Tensor]: - """Token-wise projections for DSA MLA (CUDA-graph-capturable). - - Runs kv_a_proj, layernorms, q_b_proj, and conditionally - indexer.pre_indexer_proj (FP8 quantize, weight scaling). Does NOT - update the indexer k cache — that happens in Op 2 (mla_dsa_attn_inplace) - because the scatter kernel accesses batch-specific metadata. - - Returns [q, compressed_kv, k_pe, latent_cache] when the short-MHA path - handles all tokens, or [q, compressed_kv, k_pe, latent_cache, q_fp8, - k_fp8, k_scale, weights] when the indexer runs. Under torch compile, - _should_use_short_mha returns False so the result is always length 8, - keeping control flow straight-line for CUDA graph capture. - """ - metadata, mla_layer = extract_extra_attrs(layer_idx, "mla") - return mla_layer.forward_dsa_proj(position_ids, hidden_states, metadata) - - -@mla_dsa_proj.register_fake -def _mla_dsa_proj_fake( - hidden_states: torch.Tensor, - position_ids: Optional[torch.Tensor], - layer_idx: str, -) -> List[torch.Tensor]: - # Under torch compile _should_use_short_mha is False, so always 8 tensors. - metadata, mla_layer = extract_extra_attrs(layer_idx, "mla") - num_tokens = hidden_states.shape[0] - indexer = mla_layer.mqa.indexer - q = hidden_states.new_empty( - [num_tokens, mla_layer.num_heads_tp * mla_layer.qk_head_dim]) - compressed_kv = hidden_states.new_empty( - [num_tokens, mla_layer.kv_lora_rank]) - k_pe = hidden_states.new_empty([num_tokens, mla_layer.qk_rope_head_dim]) - latent_cache = hidden_states.new_empty( - [num_tokens, mla_layer.kv_lora_rank + mla_layer.qk_rope_head_dim]) - # Indexer intermediates: q_fp8, k_fp8, k_scale, weights - q_fp8 = hidden_states.new_empty( - [num_tokens, indexer.n_heads, indexer.head_dim], - dtype=torch.float8_e4m3fn) - k_fp8 = hidden_states.new_empty([num_tokens, indexer.head_dim], - dtype=torch.float8_e4m3fn) - k_scale = hidden_states.new_empty([num_tokens, indexer.head_dim // 128], - dtype=torch.float32) - weights = hidden_states.new_empty([num_tokens, indexer.n_heads], - dtype=torch.float32) - return [ - q, compressed_kv, k_pe, latent_cache, q_fp8, k_fp8, k_scale, weights - ] - - -@torch.library.custom_op("trtllm::mla_dsa_attn_inplace", - mutates_args=("output", )) -def mla_dsa_attn_inplace( - q: torch.Tensor, - compressed_kv: torch.Tensor, - k_pe: torch.Tensor, - latent_cache: torch.Tensor, - indexer_intermediates: List[torch.Tensor], - position_ids: Optional[torch.Tensor], - layer_idx: str, - output: torch.Tensor, -) -> None: - """Batch-structure-dependent attention dispatch for DSA MLA. - - indexer_intermediates is [q_fp8, k_fp8, k_scale, weights] when the - indexer ran in Op 1, or [] when short-MHA handled all tokens. - Runs sparse_attn_indexer then dispatches context/generation attention. - This op is excluded from CUDA graph capture. - """ - metadata, mla_layer = extract_extra_attrs(layer_idx, "mla") - mla_layer.forward_dsa_attn(q, compressed_kv, k_pe, latent_cache, - indexer_intermediates, position_ids, metadata, - output) - - def fp8_block_scaling_bmm_out( mat1: torch.Tensor, mat2_fp8: torch.Tensor, @@ -2861,10 +2781,7 @@ def forward( if self.is_dsa: proj_outputs = torch.ops.trtllm.mla_dsa_proj( hidden_states, position_ids, self.layer_idx_str) - q, compressed_kv, k_pe, latent_cache = (proj_outputs[0], - proj_outputs[1], - proj_outputs[2], - proj_outputs[3]) + q, compressed_kv, k_pe, latent_cache = proj_outputs[:4] indexer_intermediates = proj_outputs[4:] torch.ops.trtllm.mla_dsa_attn_inplace( q, compressed_kv, k_pe, latent_cache, indexer_intermediates, From 4034805dfa7ccc552a3d2ce4806562917f5c44ed Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:54:33 -0700 Subject: [PATCH 05/12] [None][chore] Restore quant_block_size assertion in sparse_attn_indexer The assertion was dropped when the old Indexer.forward was split into pre_indexer_proj and sparse_attn_indexer. Restore it in sparse_attn_indexer which has access to metadata.kv_cache_manager. Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/sparse/dsa.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index ada5f5ed66d..8fa4f6c5ed4 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1499,6 +1499,9 @@ def sparse_attn_indexer( weights: torch.Tensor, use_custom_topk: bool = True, ) -> torch.Tensor: + assert metadata.kv_cache_manager is None or \ + metadata.kv_cache_manager.quant_block_size == 128, \ + "Only support quant_block_size = 128 for now" # Update the indexer k cache before prefill chunks gather from it. self._update_k_cache(k_fp8, k_scale, metadata) From 4d30add0a47216b2454c588d9c109452637ed1bf Mon Sep 17 00:00:00 2001 From: Liao Lanyu <108499334+lancelly@users.noreply.github.com> Date: Thu, 2 Apr 2026 17:38:36 +0800 Subject: [PATCH 06/12] [None][fix] Fix compute token accounting for KV cache reuse with context chunking (#12682) Signed-off-by: Liao Lanyu <108499334+lancelly@users.noreply.github.com> Co-authored-by: Liao Lanyu <108499334+lancelly@users.noreply.github.com> --- .../batch_manager/microBatchScheduler.cpp | 75 ++++++----- .../_torch/pyexecutor/model_engine.py | 18 +++ tensorrt_llm/_torch/pyexecutor/py_executor.py | 125 +++++++++++++++++- .../_torch/pyexecutor/resource_manager.py | 76 ++++++++++- 4 files changed, 252 insertions(+), 42 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp b/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp index 40b760c3cb0..3e8fca0be05 100644 --- a/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp +++ b/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp @@ -24,6 +24,25 @@ namespace tensorrt_llm::batch_manager using SizeType32 = MicroBatchScheduler::SizeType32; +/// Return the forward-pass token cost for a context chunk with KV cache reuse. +/// +/// setPrepopulatedPromptLen shifts the chunk window right by the prepopulated +/// amount rather than shrinking it. For non-last chunks the model still +/// processes approximately @p chunkSize tokens; only for the last chunk is the +/// cost @p contextRemaining - @p reusable. +static SizeType32 reuse_adjusted_compute(SizeType32 chunkSize, SizeType32 reusable, SizeType32 contextRemaining) +{ + if (reusable <= 0) + { + return chunkSize; + } + if (reusable + chunkSize < contextRemaining) + { + return chunkSize; + } + return std::max(0, contextRemaining - reusable); +} + MicroBatchScheduler::MicroBatchScheduler(std::optional ctxChunkConfig, std::optional maxContextLength, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState) @@ -38,16 +57,15 @@ void MicroBatchScheduler::fitDraftTokens(RequestVector& contextsToBeChunked, std::optional ctxTokensCapacity, SizeType32 const chunkUnitSize, std::optional const& maxContextLength) { - // How many compute tokens (chunk - reusable) are in this batch already? + // How many compute tokens are in this batch already? SizeType32 numCtxTokens{0}; for (auto const& llmReq : contextsToBeChunked) { SizeType32 const chunkSize = llmReq->getContextChunkSize(); - // contextRemaining = P for first chunk; used to compute actual model token count. SizeType32 const contextRemaining = llmReq->getContextRemainingLength(); SizeType32 const reusable = llmReq->isFirstContextChunk() ? std::min(llmReq->getEstimatedReusableTokens(), contextRemaining) : 0; - numCtxTokens += std::min(chunkSize, std::max(0, contextRemaining - reusable)); + numCtxTokens += reuse_adjusted_compute(chunkSize, reusable, contextRemaining); } // Discard draft tokens that won't fit into the existing chunk unit, max @@ -125,14 +143,13 @@ void MicroBatchScheduler::setCtxRequestsChunkSizegetContextChunkSize(); SizeType32 actualIncrement = actualChunkSize - pastChunkSize; - // Compute-aware budget: reusable tokens are served from cache and do not - // consume forward-pass capacity. Only the tokens beyond the reusable prefix count. - SizeType32 const reusable = llmReq->isFirstContextChunk() - ? std::min(llmReq->getEstimatedReusableTokens(), llmReq->getContextRemainingLength()) - : 0; - SizeType32 const pastCompute = std::max(0, pastChunkSize - std::min(reusable, pastChunkSize)); - SizeType32 const actualCompute - = std::max(0, actualChunkSize - std::min(reusable, actualChunkSize)); + // Compute-aware budget accounting for setPrepopulatedPromptLen's + // chunk-shift behaviour (non-last chunks keep their full size). + SizeType32 const contextRemaining = llmReq->getContextRemainingLength(); + SizeType32 const reusable + = llmReq->isFirstContextChunk() ? std::min(llmReq->getEstimatedReusableTokens(), contextRemaining) : 0; + SizeType32 const pastCompute = reuse_adjusted_compute(pastChunkSize, reusable, contextRemaining); + SizeType32 const actualCompute = reuse_adjusted_compute(actualChunkSize, reusable, contextRemaining); SizeType32 const computeIncrement = actualCompute - pastCompute; if ((ctxTokensCapacity && numCtxTokens + computeIncrement > ctxTokensCapacity.value()) @@ -181,24 +198,21 @@ void MicroBatchScheduler::setCtxRequestsChunkSizegetContextRemainingLength(); - // Reusable tokens are "free" — they don't consume forward-pass compute budget. SizeType32 const reusable = llmReq->isFirstContextChunk() ? std::min(llmReq->getEstimatedReusableTokens(), suggestedChunkSize) : 0; - SizeType32 const computeCost = suggestedChunkSize - reusable; + SizeType32 const computeCost = reuse_adjusted_compute(suggestedChunkSize, reusable, suggestedChunkSize); SizeType32 actualChunkSize = suggestedChunkSize; if (ctxTokensCapacity && computeCost > ctxTokensCapacity.value()) { - // Model processes min(chunk_size, P - reusable) tokens starting from position reusable. - // To keep model tokens within budget: chunk_size <= capacity (not reusable + capacity). actualChunkSize = ctxTokensCapacity.value(); } if (maxContextLength) { - // maxContextLength limits compute tokens, not total tokens. - SizeType32 const actualCompute = std::max(0, actualChunkSize - reusable); + SizeType32 const actualCompute = reuse_adjusted_compute(actualChunkSize, reusable, suggestedChunkSize); if (actualCompute > maxContextLength.value()) { - actualChunkSize = std::min(reusable + maxContextLength.value(), suggestedChunkSize); + actualChunkSize = maxContextLength.value(); + actualChunkSize = std::min(actualChunkSize, suggestedChunkSize); } } if (actualChunkSize != suggestedChunkSize) @@ -208,10 +222,7 @@ void MicroBatchScheduler::setCtxRequestsChunkSizesetContextChunkSize(actualChunkSize); if (ctxTokensCapacity) { - // Decrement by actual model token count: min(chunk_size, P - reusable). - // This equals min(actualChunkSize, computeCost) since computeCost = suggestedChunkSize - reusable. - SizeType32 const modelCost - = std::min(actualChunkSize, std::max(0, suggestedChunkSize - reusable)); + SizeType32 const modelCost = reuse_adjusted_compute(actualChunkSize, reusable, suggestedChunkSize); ctxTokensCapacity = ctxTokensCapacity.value() - modelCost; } } @@ -349,10 +360,12 @@ std::tuple MicroBatchScheduler::operator()(Request if (!mCtxChunkConfig) // skip chunking { constexpr SizeType32 beam{0}; - reqNumTokens - = llmReq->getNumTokens(beam) + (llmReq->hasDraftTokens() ? llmReq->getNumDraftTokens() : 0); - // Compute tokens = total - reusable (at least 1 to make progress) - SizeType32 const computeTokens = std::max(1, reqNumTokens - reusable); + SizeType32 const contextTokens = llmReq->getNumTokens(beam); + SizeType32 const draftTokens = llmReq->hasDraftTokens() ? llmReq->getNumDraftTokens() : 0; + reqNumTokens = contextTokens + draftTokens; + SizeType32 const contextCompute + = reuse_adjusted_compute(contextTokens, reusable, llmReq->getContextRemainingLength()); + SizeType32 const computeTokens = std::max(1, contextCompute + draftTokens); TLLM_CHECK_WITH_INFO(!mMaxContextLength || computeTokens <= mMaxContextLength.value(), "Context compute tokens (%d) exceeds the limit value (%d)", computeTokens, mMaxContextLength.value()); @@ -369,9 +382,8 @@ std::tuple MicroBatchScheduler::operator()(Request llmReq->setContextChunkSize(llmReq->getContextRemainingLength()); auto const draftTokens = (llmReq->isLastContextChunk() && llmReq->hasDraftTokens()) ? llmReq->getNumDraftTokens() : 0; - // Compute cost: context compute + draft tokens - // (reusable tokens only offset context tokens, not draft tokens) - SizeType32 const contextCompute = std::max(0, llmReq->getContextChunkSize() - reusable); + SizeType32 const contextCompute = reuse_adjusted_compute( + llmReq->getContextChunkSize(), reusable, llmReq->getContextRemainingLength()); SizeType32 computeTokens = contextCompute + draftTokens; if (mMaxContextLength) @@ -444,10 +456,9 @@ std::tuple MicroBatchScheduler::operator()(Request if (llmReq->getContextChunkSize() > 0) { contextRequests.emplace_back(llmReq); - // Only count compute tokens (total - reusable). - // Reusable credit only applies to the first context chunk. SizeType32 const reusable = llmReq->isFirstContextChunk() ? llmReq->getEstimatedReusableTokens() : 0; - SizeType32 const computeTokens = std::max(0, llmReq->getContextChunkSize() - reusable); + SizeType32 const computeTokens + = reuse_adjusted_compute(llmReq->getContextChunkSize(), reusable, llmReq->getContextRemainingLength()); batchNumTokens += computeTokens; TLLM_LOG_DEBUG("context request scheduled: ID %lu, chunk size %d%s", llmReq->mRequestId, llmReq->getContextChunkSize(), reusable > 0 ? (", reusable " + std::to_string(reusable)).c_str() : ""); diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 7c2ce12a7f4..a8e35cb77ad 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2636,6 +2636,24 @@ def previous_seq_slots_device(): num_tokens = len(input_ids) num_draft_tokens = len(draft_tokens) total_num_tokens = len(position_ids) + if total_num_tokens > self.max_num_tokens: + ctx_details = [] + for r in scheduled_requests.context_requests: + pos = r.context_current_position + csz = r.context_chunk_size + full = len(r.get_tokens(0)) + tokens = min(csz, max(0, full - pos)) + ctx_details.append( + f"rid={r.py_request_id} pos={pos} chunk={csz} " + f"full={full} tokens={tokens}") + gen_count = len(scheduled_requests.generation_requests) + from tensorrt_llm.logger import logger as _mnt_logger + _mnt_logger.error( + f"MNT overflow: total={total_num_tokens} " + f"max={self.max_num_tokens} " + f"ctx_reqs={len(scheduled_requests.context_requests)} " + f"gen_reqs={gen_count} " + f"ctx_breakdown=[{'; '.join(ctx_details)}]") assert total_num_tokens <= self.max_num_tokens, ( f"total_num_tokens ({total_num_tokens}) should be less than or equal to max_num_tokens ({self.max_num_tokens})" ) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 34012dde254..b5d0aaa1085 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2823,6 +2823,116 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest], balanced_context_requests = context_requests return balanced_context_requests + @staticmethod + def _compute_scheduled_tokens(context_requests, generation_requests): + """Compute the total number of scheduled tokens for batch waiting decisions. + + For context requests, we estimate the actual compute tokens for this + iteration (excluding tokens served from KV cache). + + For generation requests, each contributes 1 + num_draft_tokens. + + Note on reusable token handling: + estimated_reusable_tokens is an absolute count from position 0. + Depending on the scheduler, context_current_position may or may not + have been advanced past the reusable prefix by the time this method + is called: + - V1 scheduler: prepare_context runs after scheduling, so + context_current_position is still 0. + - V2 scheduler: prepare_context runs during scheduling, so + context_current_position is already advanced to the reused offset. + To handle both correctly, the reusable credit applied to the current + chunk is max(0, reusable - context_current_position), i.e. only the + portion of the reusable range that falls within this chunk's span. + """ + num_scheduled_ctx_tokens = 0 + for ctx_req in context_requests: + reusable = (ctx_req.estimated_reusable_tokens + if ctx_req.is_first_context_chunk else 0) + # Credit only the reusable tokens that overlap with the current + # chunk: if context_current_position has already been advanced past + # the reusable prefix (V2), the credit is 0; if not (V1), the full + # reusable count is subtracted. + reusable_in_chunk = max(0, + reusable - ctx_req.context_current_position) + remaining = ctx_req.context_remaining_length + if (reusable_in_chunk > 0 and + reusable_in_chunk + ctx_req.context_chunk_size < remaining): + compute = ctx_req.context_chunk_size + else: + compute = max(1, ctx_req.context_chunk_size - reusable_in_chunk) + num_scheduled_ctx_tokens += compute + num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens + for gen_req in generation_requests) + return num_scheduled_ctx_tokens + num_scheduled_gen_tokens + + def _maybe_log_batch_wait_decision( + self, + context_requests: list[LlmRequest], + generation_requests: list[LlmRequest], + num_scheduled_tokens: int, + wait_threshold: float, + should_waiting: bool, + ) -> None: + """Diagnostics for batch_wait: set TLLM_LOG_BATCH_WAIT=1 (rank 0 only).""" + if self.dist.rank != 0: + return + + num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens + for gen_req in generation_requests) + num_scheduled_ctx_formula = num_scheduled_tokens - num_scheduled_gen_tokens + + chunk_ctx_sum = 0 + ctx_summaries: List[str] = [] + max_detail = 4 + for i, ctx_req in enumerate(context_requests): + full_len = len(ctx_req.get_tokens(0)) + begin = ctx_req.context_current_position + chunk_sz = ctx_req.context_chunk_size + this_chunk = min(chunk_sz, max(0, full_len - begin)) + chunk_ctx_sum += this_chunk + reusable = (ctx_req.estimated_reusable_tokens + if ctx_req.is_first_context_chunk else 0) + reusable_in_chunk = max(0, reusable - begin) + remaining = ctx_req.context_remaining_length + if (reusable_in_chunk > 0 + and reusable_in_chunk + chunk_sz < remaining): + formula_contrib = chunk_sz + else: + formula_contrib = max(1, chunk_sz - reusable_in_chunk) + if i < max_detail: + ctx_summaries.append( + f"rid={ctx_req.py_request_id} full={full_len} pos={begin} " + f"chunk_sz={chunk_sz} this_chunk={this_chunk} " + f"reusable={reusable} formula_contrib={formula_contrib}") + n_ctx = len(context_requests) + if n_ctx > max_detail: + ctx_summaries.append(f"... +{n_ctx - max_detail} more ctx req(s)") + + logger.info( + "batch_wait: formula_total=", + num_scheduled_tokens, + " formula_ctx=", + num_scheduled_ctx_formula, + " formula_gen=", + num_scheduled_gen_tokens, + " chunk_ctx_sum=", + chunk_ctx_sum, + " threshold=", + wait_threshold, + " wait_iter=", + self.batch_wait_iters_count, + "/", + self.batch_wait_timeout_iters, + " should_defer_ctx=", + should_waiting, + " num_gen=", + len(generation_requests), + " ctx_detail=[", + "; ".join(ctx_summaries), + "]", + ) + def _waiting_requests(self, context_requests: list[LlmRequest], generation_requests: list[LlmRequest]): """ @@ -2832,13 +2942,16 @@ def _waiting_requests(self, context_requests: list[LlmRequest], - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`. """ - num_scheduled_ctx_tokens = sum( - len(ctx_req.get_tokens(0)) for ctx_req in context_requests) - num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens - for gen_req in generation_requests) - num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens + num_scheduled_tokens = self._compute_scheduled_tokens( + context_requests, generation_requests) + wait_threshold = (self.batch_wait_max_tokens_ratio * + self.max_num_tokens) - should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens + should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < wait_threshold + self._maybe_log_batch_wait_decision(context_requests, + generation_requests, + num_scheduled_tokens, + wait_threshold, should_waiting) if should_waiting: self.batch_wait_iters_count += 1 return [] diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 3c0fa9e6011..344d3913598 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -619,6 +619,22 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): # wait for all pending work to finish before launching offload/onboarding/partial copy self.impl.sync_transfer_manager_with_buffer_manager() + # Pre-addSequence budget re-validation. The C++ scheduler + # should already account for the chunk-shift cost, but under + # heavy KV-cache eviction the actual reuse may be lower than + # estimated. We re-probe the radix tree and estimate the + # true forward cost; if it exceeds the remaining budget the + # request is skipped (re-scheduled next iteration). + remaining_budget = None + if self.enable_block_reuse and not self.is_draft: + gen_tokens = sum( + req.get_beam_width_by_iter(for_next_iteration=False) + + get_draft_token_length(req) + for req in scheduled_batch.generation_requests) + remaining_budget = self.max_num_tokens - gen_tokens + + accepted_ctx_requests = [] + # allocate KV Cache for req in scheduled_batch.context_requests: req_beam_width = req.sampling_config.beam_width @@ -635,9 +651,37 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): else: if req.is_first_context_chunk and self._kv_connector_should_add_sequence( req): - self.impl.add_sequence(req.py_request_id, - req.prompt_len, req_beam_width, - req) + if remaining_budget is not None: + unique_tokens = req.get_unique_tokens(0) + reusable_blocks = self.impl.count_reusable_blocks( + unique_tokens, req, False) + actual_reuse = (reusable_blocks * + self.tokens_per_block) + req_compute = self._estimate_post_reuse_compute( + actual_reuse, req.context_chunk_size, + req.prompt_len) + if req_compute > remaining_budget: + logger.warning( + f"Reuse budget: skip req " + f"{req.py_request_id} " + f"(compute={req_compute}, " + f"chunk={req.context_chunk_size}, " + f"reuse={actual_reuse}, " + f"remaining={remaining_budget})") + continue + remaining_budget -= req_compute + + try: + self.impl.add_sequence(req.py_request_id, + req.prompt_len, + req_beam_width, req) + except RuntimeError: + logger.warning( + f"add_sequence: req " + f"{req.py_request_id} already exists, " + f"skipping") + accepted_ctx_requests.append(req) + continue for _ in range(self.num_extra_kv_tokens): self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): @@ -647,9 +691,16 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): block_ids = self.get_cache_indices(req) self.kv_connector_manager.update_state_after_alloc( req, block_ids) + elif remaining_budget is not None: + reusable = (req.estimated_reusable_tokens + if req.is_first_context_chunk else 0) + remaining_budget -= self._estimate_post_reuse_compute( + reusable, req.context_chunk_size, req.prompt_len) + + accepted_ctx_requests.append(req) # A request may change from `context_requests_chunking` to `context_requests_last_chunk` in `add_sequence` due to KV cache reuse, so we rebuild the context request lists here. - scheduled_batch.reset_context_requests() + scheduled_batch.reset_context_requests(accepted_ctx_requests) for req in scheduled_batch.generation_requests: if self.mapping.has_cp_helix(): @@ -674,6 +725,23 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): self.kv_connector_manager.build_scheduler_output( scheduled_batch, self) + def _estimate_post_reuse_compute(self, reuse_tokens: int, chunk_size: int, + prompt_len: int) -> int: + """Estimate forward compute tokens after setPrepopulatedPromptLen. + + For non-last chunks the chunk window shifts right by the reused + amount and the forward cost is approximately chunk_size. For + last chunks the cost is prompt_len - reuse (original formula). + """ + P = reuse_tokens + if P > 0 and P < prompt_len: + if P + chunk_size < prompt_len: + aligned_end = ((P + chunk_size) // self.tokens_per_block * + self.tokens_per_block) + return max(1, aligned_end - P) + return max(1, prompt_len - P) + return chunk_size + def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool: return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence( request) From efec0367d77af3aca58123775b3f22d3da634ddc Mon Sep 17 00:00:00 2001 From: Yukun He <23156053+hyukn@users.noreply.github.com> Date: Thu, 2 Apr 2026 18:04:35 +0800 Subject: [PATCH 07/12] [https://nvbugs/5983390][perf] Cherry-pick #12581: Multiple host perf optimizations for DSA part (#12681) Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com> --- tensorrt_llm/_torch/attention_backend/sparse/dsa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 8fa4f6c5ed4..f11e9a35289 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -688,7 +688,6 @@ def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens): def prepare_dense_topk_indices(self, kv_lens, device=False): # device=False means use CPU - """Prepare dense TopK indices for short sequences that skip the indexer.""" if self.num_contexts > 0 and self.skip_indexer_for_ctx_reqs: ctx_range = slice(self.num_ctx_tokens) From 7596a135495df71f894897a136d930f4a281c7f9 Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Thu, 2 Apr 2026 22:30:53 -0700 Subject: [PATCH 08/12] [https://nvbugs/5983390][perf] Reduce host overhead in DSA MLA attention path (#12691) Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- cpp/tensorrt_llm/nanobind/thop/bindings.cpp | 4 +- .../thop/IndexerKCacheScatterOp.cpp | 101 +++++++++--------- cpp/tensorrt_llm/thop/attentionOp.cpp | 18 +--- cpp/tensorrt_llm/thop/attentionOp.h | 2 +- .../_torch/attention_backend/sparse/dsa.py | 32 ++---- .../_torch/attention_backend/trtllm.py | 10 +- .../_torch/attention_backend/trtllm_gen.py | 34 +----- .../custom_ops/attention/trtllm_attention.py | 9 ++ .../attention/sparse/test_dsa_indexer.py | 11 +- 9 files changed, 92 insertions(+), 129 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp index fc161ab4a6c..b71c39d4087 100644 --- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -70,8 +70,8 @@ void initBindings(nb::module_& m) nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt, nb::arg("flash_mla_tile_scheduler_metadata") = std::nullopt, - nb::arg("flash_mla_num_splits") = std::nullopt, "Multi-head attention operation", - nb::call_guard()); + nb::arg("flash_mla_num_splits") = std::nullopt, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0, + "Multi-head attention operation", nb::call_guard()); m.def( "get_helix_workspace_size_per_rank", diff --git a/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp index 940d59258ca..f5a1336ea3e 100644 --- a/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp +++ b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp @@ -28,69 +28,66 @@ TRTLLM_NAMESPACE_BEGIN namespace torch_ext { -void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache, - th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale) +void indexer_k_cache_scatter_op(th::Tensor const& k_fp8, th::Tensor const& k_scale, th::Tensor& k_cache, + th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale, int64_t num_tokens) { - // Validate all tensors are CUDA tensors - TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda() + // k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8 + // k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes + // slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used + // k_cache: [num_blocks, block_size, 1, per_token_size] uint8 + + TORCH_CHECK(k_fp8.is_cuda() && k_scale.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda() && slot_mapping_scale.is_cuda(), "All tensors must be CUDA tensors"); // Validate tensor dimensions - TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]"); - TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]"); - TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]"); - TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]"); - - // Enforce k_cache is 4D tensor - TORCH_CHECK(k_cache.dim() == 4, - "k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions", + TORCH_CHECK(k_fp8.dim() == 2, "k_fp8 must be 2D [num_tokens, head_dim]"); + TORCH_CHECK(k_scale.dim() == 2, "k_scale must be 2D [num_tokens, scale_elements]"); + TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be 1D [num_tokens]"); + TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be 1D [num_tokens]"); + TORCH_CHECK(k_cache.dim() == 4, "k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims", static_cast(k_cache.dim())); - // Validate tensor dtypes - TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8"); - TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8"); + // Validate tensor dtypes — reinterpret_cast below assumes specific element sizes + TORCH_CHECK(k_fp8.element_size() == 1, "k_fp8 must have 1-byte elements (e.g. FP8), got %d", k_fp8.element_size()); + TORCH_CHECK(k_scale.element_size() == 4, "k_scale must have 4-byte elements (e.g. float32), got %d", + k_scale.element_size()); TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64"); TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64"); - // Validate tensor shapes are consistent - auto num_tokens = static_cast(k_fp8_bytes.size(0)); - TORCH_CHECK( - k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension"); - TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens"); - TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens"); - - // Validate tensors are contiguous (except k_cache which may be non-contiguous) - TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous"); - TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous"); - // k_cache can be non-contiguous - we handle this via strides + TORCH_CHECK(k_fp8.is_contiguous(), "k_fp8 must be contiguous"); + TORCH_CHECK(k_scale.is_contiguous(), "k_scale must be contiguous"); TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous"); TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous"); - int32_t head_dim = static_cast(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128 - int32_t scale_size = static_cast(k_scale_bytes.size(1)); // scale_size = 4 bytes - - int32_t cache_dim_0 = static_cast(k_cache.size(0)); // num_blocks - int32_t cache_dim_1 = static_cast(k_cache.size(1)); // block_size - int32_t cache_dim_2 = static_cast(k_cache.size(2)); // num_kv_heads - int32_t cache_dim_3 = static_cast(k_cache.size(3)); // per_token_size - - // Validation for indexer k cache pool for DeepSeek-V3.2 constraints - TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2); - TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim); - TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size); - - int64_t cache_stride_0 = static_cast(k_cache.stride(0)); - int64_t cache_stride_1 = static_cast(k_cache.stride(1)); - int64_t cache_stride_2 = static_cast(k_cache.stride(2)); - int64_t cache_stride_3 = static_cast(k_cache.stride(3)); - - auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device()); - - tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr(), k_scale_bytes.data_ptr(), - k_cache.data_ptr(), slot_mapping_fp8.data_ptr(), slot_mapping_scale.data_ptr(), - num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, - cache_stride_1, cache_stride_2, cache_stride_3, stream); + // FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes. + int32_t const head_dim = static_cast(k_fp8.size(1)); + // Scale size in bytes: num_scale_elements * bytes_per_element. + int32_t const scale_size = static_cast(k_scale.size(1)) * static_cast(k_scale.element_size()); + + int32_t const cache_dim_0 = static_cast(k_cache.size(0)); + int32_t const cache_dim_1 = static_cast(k_cache.size(1)); + int32_t const cache_dim_2 = static_cast(k_cache.size(2)); + int32_t const cache_dim_3 = static_cast(k_cache.size(3)); + + TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1, got %d", cache_dim_2); + TORCH_CHECK(head_dim == 128, "k_fp8 head_dim must be 128, got %d", head_dim); + TORCH_CHECK(scale_size == 4, "k_scale scale_size must be 4 bytes, got %d", scale_size); + + int64_t const cache_stride_0 = static_cast(k_cache.stride(0)); + int64_t const cache_stride_1 = static_cast(k_cache.stride(1)); + int64_t const cache_stride_2 = static_cast(k_cache.stride(2)); + int64_t const cache_stride_3 = static_cast(k_cache.stride(3)); + + auto stream = at::cuda::getCurrentCUDAStream(k_fp8.get_device()); + + // Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr. + // For slot mappings, use data_ptr directly — only the first num_tokens entries are read. + tk::invokeIndexerKCacheScatter(reinterpret_cast(k_fp8.data_ptr()), + reinterpret_cast(k_scale.data_ptr()), k_cache.data_ptr(), + slot_mapping_fp8.data_ptr(), slot_mapping_scale.data_ptr(), static_cast(num_tokens), + head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1, + cache_stride_2, cache_stride_3, stream); } } // namespace torch_ext @@ -100,8 +97,8 @@ TRTLLM_NAMESPACE_END TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( - "indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, " - "Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()"); + "indexer_k_cache_scatter_op(Tensor k_fp8, Tensor k_scale, Tensor(a!) k_cache, " + "Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens) -> ()"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 9a7af4da49f..b526310564e 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -630,7 +630,8 @@ void attention(torch::Tensor q, std::optional k, std::optional cu_q_seqlens, std::optional cu_kv_seqlens, std::optional fmha_scheduler_counter, std::optional mla_bmm1_scale, std::optional mla_bmm2_scale, std::optional quant_q_buffer, - std::optional flash_mla_tile_scheduler_metadata, std::optional flash_mla_num_splits) + std::optional flash_mla_tile_scheduler_metadata, std::optional flash_mla_num_splits, + int64_t num_contexts, int64_t num_ctx_tokens) { TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx); // Use these tensors to infer if the attention is using KV cache @@ -833,20 +834,9 @@ void attention(torch::Tensor q, std::optional k, std::optional(num_contexts); int32_t const num_tokens = qkv_or_q.size(0); - int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item(); - int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens; + int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - static_cast(num_ctx_tokens); auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item(); auto const gen_total_kv_len = host_total_kv_lens.index({1}).item(); diff --git a/cpp/tensorrt_llm/thop/attentionOp.h b/cpp/tensorrt_llm/thop/attentionOp.h index 0fc4788d6f0..854de818292 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.h +++ b/cpp/tensorrt_llm/thop/attentionOp.h @@ -78,7 +78,7 @@ void attention(torch::Tensor q, std::optional k, std::optional fmha_scheduler_counter, std::optional mla_bmm1_scale, std::optional mla_bmm2_scale, std::optional quant_q_buffer, std::optional flash_mla_tile_scheduler_metadata = std::nullopt, - std::optional flash_mla_num_splits = std::nullopt); + std::optional flash_mla_num_splits = std::nullopt, int64_t num_contexts, int64_t num_ctx_tokens); struct KvCachePoolPointers { diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index f11e9a35289..614ed98e6c8 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1459,34 +1459,18 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor, if metadata.kv_cache_manager is None or metadata.slot_mapping_fp8 is None: return - # [num_blocks, block_size, 1, per_token_size ] k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers( self.layer_idx) num_tokens = k_fp8.shape[0] - head_dim = k_fp8.shape[1] - scale_size = k_scale.shape[1] * 4 # Convert to bytes (float32 = 4 bytes) - - # Convert to bytes: flatten first, then view as uint8, then reshape - k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view( - num_tokens, head_dim) - - # k_scale: for single-element tensors, contiguous() may be no-op - # Fix stride(-1) for byte-level view - k_scale_flat = k_scale.view(-1) - if k_scale_flat.stride(-1) != 1: - k_scale_flat = torch.as_strided(k_scale_flat.contiguous(), - size=(k_scale_flat.numel(), ), - stride=(1, )) - k_scale_bytes = k_scale_flat.view(torch.uint8).view( - num_tokens, scale_size) - - # Use CUDA kernel to scatter FP8 and scale bytes into cache - flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens] - flat_indices_scale = metadata.slot_mapping_scale[:num_tokens] - torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, - k_cache, flat_indices_fp8, - flat_indices_scale) + + # The C++ op reinterprets k_fp8 (FP8) and k_scale (float32) as raw + # bytes internally and only reads the first num_tokens entries from + # the slot mapping buffers, avoiding Python-side view/slice overhead. + torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache, + metadata.slot_mapping_fp8, + metadata.slot_mapping_scale, + num_tokens) def sparse_attn_indexer( self, diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 7f1b7c0c4df..b1da0f66fe0 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -407,6 +407,8 @@ def run( mla_bmm1_scale: Optional[torch.Tensor] = None, mla_bmm2_scale: Optional[torch.Tensor] = None, quant_q_buffer: Optional[torch.Tensor] = None, + num_contexts: int = 0, + num_ctx_tokens: int = 0, ): """ Run the attention operation. @@ -638,6 +640,8 @@ def run( quant_q_buffer, self.quant_config, self.kv_cache_manager, + num_contexts, + num_ctx_tokens, global_layer_idx=self.global_layer_idx, ) else: @@ -722,6 +726,8 @@ def run( quant_q_buffer, self.flash_mla_tile_scheduler_metadata, self.flash_mla_num_splits, + num_contexts, + num_ctx_tokens, ) if self.print_skip_softmax_stat: @@ -2049,7 +2055,9 @@ def forward( fmha_scheduler_counter=fmha_scheduler_counter, mla_bmm1_scale=mla_bmm1_scale, mla_bmm2_scale=mla_bmm2_scale, - quant_q_buffer=quant_q_buffer) + quant_q_buffer=quant_q_buffer, + num_contexts=metadata.num_contexts, + num_ctx_tokens=metadata.num_ctx_tokens) if output_sf is None: return output diff --git a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py index 439831d4bf7..ff0ac200401 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py @@ -1437,23 +1437,6 @@ def run_mla_generation(self, params: EnqueueGenerationParams) -> None: params.context_buf.copy_(mla_out.reshape_as(params.context_buf)) -def _parse_request_types(host_request_types: torch.Tensor) -> Tuple[int, int]: - """ - Parse request types to count context and generation requests. - - Args: - host_request_types: Request types tensor (0=context, 1=generation). - num_seqs: Total number of sequences. - - Returns: - Tuple of (num_contexts, num_generations). - """ - - num_generations = host_request_types.sum().item() - num_contexts = host_request_types.size(0) - num_generations - return num_contexts, num_generations - - def is_supported( q: torch.Tensor, num_heads: int, @@ -1636,6 +1619,8 @@ def trtllm_gen_attention( quant_q_buffer: Optional[torch.Tensor], quant_config: Optional[QuantConfig], kv_cache_manager: Optional[KVCacheManager], + num_contexts: int, + num_ctx_tokens: int, global_layer_idx: Optional[int] = None, ) -> None: """ @@ -1766,20 +1751,9 @@ def trtllm_gen_attention( if attention_input_type is not None: attn_input_type = AttentionInputType(attention_input_type) - num_contexts, num_generations = _parse_request_types(host_request_types) - is_gen_only = attn_input_type == AttentionInputType.generation_only - is_ctx_only = attn_input_type == AttentionInputType.context_only - - if is_gen_only: - num_ctx_tokens = 0 - num_gen_tokens = num_tokens - elif is_ctx_only: - num_ctx_tokens = num_tokens - num_gen_tokens = 0 - else: - num_ctx_tokens = int(host_context_lengths[:num_contexts].sum()) if num_contexts > 0 else 0 - num_gen_tokens = num_tokens - num_ctx_tokens + num_generations = host_request_types.size(0) - num_contexts + num_gen_tokens = num_tokens - num_ctx_tokens # Prepare Workspace # Use upper-bound token counts for workspace sizing to avoid repeated diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py index f1c99267ed0..5d7895fd0f1 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py @@ -82,6 +82,9 @@ def __init__(self): # keeping a separate copy here since we sometimes have to overwrite the original values self.host_past_kv_lengths: Optional[torch.Tensor] = None # [max_batch] int32 pinned self.host_context_lengths: Optional[torch.Tensor] = None # [max_batch] int32 pinned + # Batch counts for thop.attention (updated every forward in plan_host) + self.num_contexts: int = 0 + self.num_ctx_tokens: int = 0 # Persistent block_offsets buffer for CUDA graph compatibility. # Pre-allocated to max size so the tensor address is stable across replays. self.block_offsets: Optional[torch.Tensor] = None @@ -171,6 +174,10 @@ def plan_host( """ num_seq = num_prefill + num_decode + # Batch counts for thop.attention + self.num_contexts = num_prefill + self.num_ctx_tokens = int(seq_len_host[:num_prefill].sum()) if num_prefill > 0 else 0 + # host_request_types: 0 = prefill (context), 1 = decode (generation) self.host_request_types[:num_prefill].fill_(0) self.host_request_types[num_prefill:num_seq].fill_(1) @@ -500,6 +507,8 @@ def trtllm_mha_with_cache( None, # mla_bmm1_scale None, # mla_bmm2_scale None, # quant_q_buffer + num_contexts=_GlobalTrtllmPlanner.num_contexts, + num_ctx_tokens=_GlobalTrtllmPlanner.num_ctx_tokens, ) if out is not None: diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py index b1054a27dd4..0bf2506885a 100644 --- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py +++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py @@ -716,7 +716,7 @@ def test_indexer_k_cache_scatter_custom_op(): dtype=torch.bfloat16) k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original) - # Prepare byte-level data + # Prepare byte-level data for the Python reference path scale_size = k_scale.shape[1] * 4 k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim) k_scale_flat = k_scale.view(-1) @@ -754,10 +754,11 @@ def test_indexer_k_cache_scatter_custom_op(): print(f" is_contiguous: {k_cache_python.is_contiguous()}") # ========== Path 1: CUDA Kernel ========== - print("\n=== Path 1: CUDA Kernel ===") - torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, - k_cache_cuda, flat_indices_fp8, - flat_indices_scale) + print(f"\n=== Path 1: CUDA Kernel ===") + torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache_cuda, + metadata.slot_mapping_fp8, + metadata.slot_mapping_scale, + num_tokens) torch.cuda.synchronize() print("✓ CUDA kernel completed") From d39e03be8de589c0818cbc1562311a46328719d7 Mon Sep 17 00:00:00 2001 From: v-shobhit <161510941+v-shobhit@users.noreply.github.com> Date: Sun, 5 Apr 2026 17:27:59 -0700 Subject: [PATCH 09/12] Add default values for num_contexts and num_ctx_tokens Signed-off-by: v-shobhit <161510941+v-shobhit@users.noreply.github.com> --- cpp/tensorrt_llm/thop/attentionOp.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/thop/attentionOp.h b/cpp/tensorrt_llm/thop/attentionOp.h index 854de818292..77f9a965e48 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.h +++ b/cpp/tensorrt_llm/thop/attentionOp.h @@ -78,7 +78,9 @@ void attention(torch::Tensor q, std::optional k, std::optional fmha_scheduler_counter, std::optional mla_bmm1_scale, std::optional mla_bmm2_scale, std::optional quant_q_buffer, std::optional flash_mla_tile_scheduler_metadata = std::nullopt, - std::optional flash_mla_num_splits = std::nullopt, int64_t num_contexts, int64_t num_ctx_tokens); + std::optional flash_mla_num_splits = std::nullopt, + int64_t num_contexts = 0, + int64_t num_ctx_tokens = 0); struct KvCachePoolPointers { From 3d59f85a54f6bc90bb94060be137c75e1cb1e9cb Mon Sep 17 00:00:00 2001 From: Jin Li <59594262+liji-nv@users.noreply.github.com> Date: Tue, 7 Apr 2026 06:12:42 -0700 Subject: [PATCH 10/12] =?UTF-8?q?[None][fix]=20Pre-subtract=20non-first-ch?= =?UTF-8?q?unk=20context=20costs=20in=20reuse=20budge=E2=80=A6=20(#12806)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 344d3913598..6b368eabc39 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -633,6 +633,14 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): for req in scheduled_batch.generation_requests) remaining_budget = self.max_num_tokens - gen_tokens + # Pre-subtract the fixed cost of non-first-chunk context + # requests. These have no reuse to re-validate and their + # compute cost is committed, so first-chunk budget checks + # must see the budget with these costs already removed. + for req in scheduled_batch.context_requests: + if not req.is_first_context_chunk: + remaining_budget -= req.context_chunk_size + accepted_ctx_requests = [] # allocate KV Cache @@ -691,9 +699,13 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): block_ids = self.get_cache_indices(req) self.kv_connector_manager.update_state_after_alloc( req, block_ids) - elif remaining_budget is not None: - reusable = (req.estimated_reusable_tokens - if req.is_first_context_chunk else 0) + elif remaining_budget is not None and req.is_first_context_chunk: + # First-chunk request that skipped add_sequence + # (e.g. kv_connector said not to). Subtract its + # estimated cost so later first-chunk checks see a + # correct budget. Non-first-chunk costs were + # already pre-subtracted above. + reusable = req.estimated_reusable_tokens remaining_budget -= self._estimate_post_reuse_compute( reusable, req.context_chunk_size, req.prompt_len) From fba201080a1240ad396db2696450a66a359e38a1 Mon Sep 17 00:00:00 2001 From: Dongfeng Yu Date: Thu, 9 Apr 2026 00:30:17 +0000 Subject: [PATCH 11/12] fix comments Signed-off-by: Dongfeng Yu --- tensorrt_llm/_torch/attention_backend/sparse/dsa.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 614ed98e6c8..1e63c3c111e 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -688,6 +688,7 @@ def _get_dense_topk_indices(self, seq_lens, kv_lens, num_tokens): def prepare_dense_topk_indices(self, kv_lens, device=False): # device=False means use CPU + """Prepare dense TopK indices for short sequences that skip the indexer.""" if self.num_contexts > 0 and self.skip_indexer_for_ctx_reqs: ctx_range = slice(self.num_ctx_tokens) @@ -1459,6 +1460,7 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor, if metadata.kv_cache_manager is None or metadata.slot_mapping_fp8 is None: return + # [num_blocks, block_size, 1, per_token_size ] k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers( self.layer_idx) @@ -1482,6 +1484,7 @@ def sparse_attn_indexer( weights: torch.Tensor, use_custom_topk: bool = True, ) -> torch.Tensor: + """Run the indexer TopK kernel for both prefill and decode phases.""" assert metadata.kv_cache_manager is None or \ metadata.kv_cache_manager.quant_block_size == 128, \ "Only support quant_block_size = 128 for now" @@ -1788,6 +1791,7 @@ def pre_indexer_proj( position_ids: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Pure token-wise projections (CUDA-graph-capturable). + Runs cublas_mm, qk_projection_and_rope, FP8 quantize, and weight scaling. Does NOT touch the k cache or any batch-specific metadata, so this can safely run inside a captured CUDA graph partition. From cb206e6b0eb123b9d0c4aa8330e7659e40c3b259 Mon Sep 17 00:00:00 2001 From: Dongfeng Yu Date: Thu, 9 Apr 2026 00:51:43 +0000 Subject: [PATCH 12/12] fix comments Signed-off-by: Dongfeng Yu --- cpp/tensorrt_llm/thop/attentionOp.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/thop/attentionOp.h b/cpp/tensorrt_llm/thop/attentionOp.h index 77f9a965e48..cc2b3f787f0 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.h +++ b/cpp/tensorrt_llm/thop/attentionOp.h @@ -78,8 +78,7 @@ void attention(torch::Tensor q, std::optional k, std::optional fmha_scheduler_counter, std::optional mla_bmm1_scale, std::optional mla_bmm2_scale, std::optional quant_q_buffer, std::optional flash_mla_tile_scheduler_metadata = std::nullopt, - std::optional flash_mla_num_splits = std::nullopt, - int64_t num_contexts = 0, + std::optional flash_mla_num_splits = std::nullopt, int64_t num_contexts = 0, int64_t num_ctx_tokens = 0); struct KvCachePoolPointers