diff --git a/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp b/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp index 40b760c3cb0..3e8fca0be05 100644 --- a/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp +++ b/cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp @@ -24,6 +24,25 @@ namespace tensorrt_llm::batch_manager using SizeType32 = MicroBatchScheduler::SizeType32; +/// Return the forward-pass token cost for a context chunk with KV cache reuse. +/// +/// setPrepopulatedPromptLen shifts the chunk window right by the prepopulated +/// amount rather than shrinking it. For non-last chunks the model still +/// processes approximately @p chunkSize tokens; only for the last chunk is the +/// cost @p contextRemaining - @p reusable. +static SizeType32 reuse_adjusted_compute(SizeType32 chunkSize, SizeType32 reusable, SizeType32 contextRemaining) +{ + if (reusable <= 0) + { + return chunkSize; + } + if (reusable + chunkSize < contextRemaining) + { + return chunkSize; + } + return std::max(0, contextRemaining - reusable); +} + MicroBatchScheduler::MicroBatchScheduler(std::optional ctxChunkConfig, std::optional maxContextLength, LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState) @@ -38,16 +57,15 @@ void MicroBatchScheduler::fitDraftTokens(RequestVector& contextsToBeChunked, std::optional ctxTokensCapacity, SizeType32 const chunkUnitSize, std::optional const& maxContextLength) { - // How many compute tokens (chunk - reusable) are in this batch already? + // How many compute tokens are in this batch already? SizeType32 numCtxTokens{0}; for (auto const& llmReq : contextsToBeChunked) { SizeType32 const chunkSize = llmReq->getContextChunkSize(); - // contextRemaining = P for first chunk; used to compute actual model token count. SizeType32 const contextRemaining = llmReq->getContextRemainingLength(); SizeType32 const reusable = llmReq->isFirstContextChunk() ? 
std::min(llmReq->getEstimatedReusableTokens(), contextRemaining) : 0; - numCtxTokens += std::min(chunkSize, std::max(0, contextRemaining - reusable)); + numCtxTokens += reuse_adjusted_compute(chunkSize, reusable, contextRemaining); } // Discard draft tokens that won't fit into the existing chunk unit, max @@ -125,14 +143,13 @@ void MicroBatchScheduler::setCtxRequestsChunkSizegetContextChunkSize(); SizeType32 actualIncrement = actualChunkSize - pastChunkSize; - // Compute-aware budget: reusable tokens are served from cache and do not - // consume forward-pass capacity. Only the tokens beyond the reusable prefix count. - SizeType32 const reusable = llmReq->isFirstContextChunk() - ? std::min(llmReq->getEstimatedReusableTokens(), llmReq->getContextRemainingLength()) - : 0; - SizeType32 const pastCompute = std::max(0, pastChunkSize - std::min(reusable, pastChunkSize)); - SizeType32 const actualCompute - = std::max(0, actualChunkSize - std::min(reusable, actualChunkSize)); + // Compute-aware budget accounting for setPrepopulatedPromptLen's + // chunk-shift behaviour (non-last chunks keep their full size). + SizeType32 const contextRemaining = llmReq->getContextRemainingLength(); + SizeType32 const reusable + = llmReq->isFirstContextChunk() ? std::min(llmReq->getEstimatedReusableTokens(), contextRemaining) : 0; + SizeType32 const pastCompute = reuse_adjusted_compute(pastChunkSize, reusable, contextRemaining); + SizeType32 const actualCompute = reuse_adjusted_compute(actualChunkSize, reusable, contextRemaining); SizeType32 const computeIncrement = actualCompute - pastCompute; if ((ctxTokensCapacity && numCtxTokens + computeIncrement > ctxTokensCapacity.value()) @@ -181,24 +198,21 @@ void MicroBatchScheduler::setCtxRequestsChunkSizegetContextRemainingLength(); - // Reusable tokens are "free" — they don't consume forward-pass compute budget. SizeType32 const reusable = llmReq->isFirstContextChunk() ? 
std::min(llmReq->getEstimatedReusableTokens(), suggestedChunkSize) : 0; - SizeType32 const computeCost = suggestedChunkSize - reusable; + SizeType32 const computeCost = reuse_adjusted_compute(suggestedChunkSize, reusable, suggestedChunkSize); SizeType32 actualChunkSize = suggestedChunkSize; if (ctxTokensCapacity && computeCost > ctxTokensCapacity.value()) { - // Model processes min(chunk_size, P - reusable) tokens starting from position reusable. - // To keep model tokens within budget: chunk_size <= capacity (not reusable + capacity). actualChunkSize = ctxTokensCapacity.value(); } if (maxContextLength) { - // maxContextLength limits compute tokens, not total tokens. - SizeType32 const actualCompute = std::max(0, actualChunkSize - reusable); + SizeType32 const actualCompute = reuse_adjusted_compute(actualChunkSize, reusable, suggestedChunkSize); if (actualCompute > maxContextLength.value()) { - actualChunkSize = std::min(reusable + maxContextLength.value(), suggestedChunkSize); + actualChunkSize = maxContextLength.value(); + actualChunkSize = std::min(actualChunkSize, suggestedChunkSize); } } if (actualChunkSize != suggestedChunkSize) @@ -208,10 +222,7 @@ void MicroBatchScheduler::setCtxRequestsChunkSizesetContextChunkSize(actualChunkSize); if (ctxTokensCapacity) { - // Decrement by actual model token count: min(chunk_size, P - reusable). - // This equals min(actualChunkSize, computeCost) since computeCost = suggestedChunkSize - reusable. - SizeType32 const modelCost - = std::min(actualChunkSize, std::max(0, suggestedChunkSize - reusable)); + SizeType32 const modelCost = reuse_adjusted_compute(actualChunkSize, reusable, suggestedChunkSize); ctxTokensCapacity = ctxTokensCapacity.value() - modelCost; } } @@ -349,10 +360,12 @@ std::tuple MicroBatchScheduler::operator()(Request if (!mCtxChunkConfig) // skip chunking { constexpr SizeType32 beam{0}; - reqNumTokens - = llmReq->getNumTokens(beam) + (llmReq->hasDraftTokens() ? 
llmReq->getNumDraftTokens() : 0); - // Compute tokens = total - reusable (at least 1 to make progress) - SizeType32 const computeTokens = std::max(1, reqNumTokens - reusable); + SizeType32 const contextTokens = llmReq->getNumTokens(beam); + SizeType32 const draftTokens = llmReq->hasDraftTokens() ? llmReq->getNumDraftTokens() : 0; + reqNumTokens = contextTokens + draftTokens; + SizeType32 const contextCompute + = reuse_adjusted_compute(contextTokens, reusable, llmReq->getContextRemainingLength()); + SizeType32 const computeTokens = std::max(1, contextCompute + draftTokens); TLLM_CHECK_WITH_INFO(!mMaxContextLength || computeTokens <= mMaxContextLength.value(), "Context compute tokens (%d) exceeds the limit value (%d)", computeTokens, mMaxContextLength.value()); @@ -369,9 +382,8 @@ std::tuple MicroBatchScheduler::operator()(Request llmReq->setContextChunkSize(llmReq->getContextRemainingLength()); auto const draftTokens = (llmReq->isLastContextChunk() && llmReq->hasDraftTokens()) ? llmReq->getNumDraftTokens() : 0; - // Compute cost: context compute + draft tokens - // (reusable tokens only offset context tokens, not draft tokens) - SizeType32 const contextCompute = std::max(0, llmReq->getContextChunkSize() - reusable); + SizeType32 const contextCompute = reuse_adjusted_compute( + llmReq->getContextChunkSize(), reusable, llmReq->getContextRemainingLength()); SizeType32 computeTokens = contextCompute + draftTokens; if (mMaxContextLength) @@ -444,10 +456,9 @@ std::tuple MicroBatchScheduler::operator()(Request if (llmReq->getContextChunkSize() > 0) { contextRequests.emplace_back(llmReq); - // Only count compute tokens (total - reusable). - // Reusable credit only applies to the first context chunk. SizeType32 const reusable = llmReq->isFirstContextChunk() ? 
llmReq->getEstimatedReusableTokens() : 0; - SizeType32 const computeTokens = std::max(0, llmReq->getContextChunkSize() - reusable); + SizeType32 const computeTokens + = reuse_adjusted_compute(llmReq->getContextChunkSize(), reusable, llmReq->getContextRemainingLength()); batchNumTokens += computeTokens; TLLM_LOG_DEBUG("context request scheduled: ID %lu, chunk size %d%s", llmReq->mRequestId, llmReq->getContextChunkSize(), reusable > 0 ? (", reusable " + std::to_string(reusable)).c_str() : ""); diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp index fc161ab4a6c..b71c39d4087 100644 --- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp @@ -70,8 +70,8 @@ void initBindings(nb::module_& m) nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt, nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt, nb::arg("quant_q_buffer") = std::nullopt, nb::arg("flash_mla_tile_scheduler_metadata") = std::nullopt, - nb::arg("flash_mla_num_splits") = std::nullopt, "Multi-head attention operation", - nb::call_guard()); + nb::arg("flash_mla_num_splits") = std::nullopt, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0, + "Multi-head attention operation", nb::call_guard()); m.def( "get_helix_workspace_size_per_rank", diff --git a/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp index 940d59258ca..f5a1336ea3e 100644 --- a/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp +++ b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp @@ -28,69 +28,66 @@ TRTLLM_NAMESPACE_BEGIN namespace torch_ext { -void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache, - th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale) +void indexer_k_cache_scatter_op(th::Tensor const& k_fp8, th::Tensor const& k_scale, th::Tensor& 
k_cache, + th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale, int64_t num_tokens) { - // Validate all tensors are CUDA tensors - TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda() + // k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8 + // k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes + // slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used + // k_cache: [num_blocks, block_size, 1, per_token_size] uint8 + + TORCH_CHECK(k_fp8.is_cuda() && k_scale.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda() && slot_mapping_scale.is_cuda(), "All tensors must be CUDA tensors"); // Validate tensor dimensions - TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]"); - TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]"); - TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]"); - TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]"); - - // Enforce k_cache is 4D tensor - TORCH_CHECK(k_cache.dim() == 4, - "k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions", + TORCH_CHECK(k_fp8.dim() == 2, "k_fp8 must be 2D [num_tokens, head_dim]"); + TORCH_CHECK(k_scale.dim() == 2, "k_scale must be 2D [num_tokens, scale_elements]"); + TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be 1D [num_tokens]"); + TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be 1D [num_tokens]"); + TORCH_CHECK(k_cache.dim() == 4, "k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims", static_cast(k_cache.dim())); - // Validate tensor dtypes - TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8"); - 
TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8"); + // Validate tensor dtypes — reinterpret_cast below assumes specific element sizes + TORCH_CHECK(k_fp8.element_size() == 1, "k_fp8 must have 1-byte elements (e.g. FP8), got %d", k_fp8.element_size()); + TORCH_CHECK(k_scale.element_size() == 4, "k_scale must have 4-byte elements (e.g. float32), got %d", + k_scale.element_size()); TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64"); TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64"); - // Validate tensor shapes are consistent - auto num_tokens = static_cast(k_fp8_bytes.size(0)); - TORCH_CHECK( - k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension"); - TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens"); - TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens"); - - // Validate tensors are contiguous (except k_cache which may be non-contiguous) - TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous"); - TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous"); - // k_cache can be non-contiguous - we handle this via strides + TORCH_CHECK(k_fp8.is_contiguous(), "k_fp8 must be contiguous"); + TORCH_CHECK(k_scale.is_contiguous(), "k_scale must be contiguous"); TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous"); TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous"); - int32_t head_dim = static_cast(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128 - int32_t scale_size = static_cast(k_scale_bytes.size(1)); // scale_size = 4 bytes - - int32_t cache_dim_0 = static_cast(k_cache.size(0)); // num_blocks - int32_t cache_dim_1 = static_cast(k_cache.size(1)); // block_size - int32_t 
cache_dim_2 = static_cast(k_cache.size(2)); // num_kv_heads - int32_t cache_dim_3 = static_cast(k_cache.size(3)); // per_token_size - - // Validation for indexer k cache pool for DeepSeek-V3.2 constraints - TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2); - TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim); - TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size); - - int64_t cache_stride_0 = static_cast(k_cache.stride(0)); - int64_t cache_stride_1 = static_cast(k_cache.stride(1)); - int64_t cache_stride_2 = static_cast(k_cache.stride(2)); - int64_t cache_stride_3 = static_cast(k_cache.stride(3)); - - auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device()); - - tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr(), k_scale_bytes.data_ptr(), - k_cache.data_ptr(), slot_mapping_fp8.data_ptr(), slot_mapping_scale.data_ptr(), - num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, - cache_stride_1, cache_stride_2, cache_stride_3, stream); + // FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes. + int32_t const head_dim = static_cast(k_fp8.size(1)); + // Scale size in bytes: num_scale_elements * bytes_per_element. 
+ int32_t const scale_size = static_cast(k_scale.size(1)) * static_cast(k_scale.element_size()); + + int32_t const cache_dim_0 = static_cast(k_cache.size(0)); + int32_t const cache_dim_1 = static_cast(k_cache.size(1)); + int32_t const cache_dim_2 = static_cast(k_cache.size(2)); + int32_t const cache_dim_3 = static_cast(k_cache.size(3)); + + TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1, got %d", cache_dim_2); + TORCH_CHECK(head_dim == 128, "k_fp8 head_dim must be 128, got %d", head_dim); + TORCH_CHECK(scale_size == 4, "k_scale scale_size must be 4 bytes, got %d", scale_size); + + int64_t const cache_stride_0 = static_cast(k_cache.stride(0)); + int64_t const cache_stride_1 = static_cast(k_cache.stride(1)); + int64_t const cache_stride_2 = static_cast(k_cache.stride(2)); + int64_t const cache_stride_3 = static_cast(k_cache.stride(3)); + + auto stream = at::cuda::getCurrentCUDAStream(k_fp8.get_device()); + + // Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr. + // For slot mappings, use data_ptr directly — only the first num_tokens entries are read. + tk::invokeIndexerKCacheScatter(reinterpret_cast(k_fp8.data_ptr()), + reinterpret_cast(k_scale.data_ptr()), k_cache.data_ptr(), + slot_mapping_fp8.data_ptr(), slot_mapping_scale.data_ptr(), static_cast(num_tokens), + head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1, + cache_stride_2, cache_stride_3, stream); } } // namespace torch_ext @@ -100,8 +97,8 @@ TRTLLM_NAMESPACE_END TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( - "indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, " - "Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()"); + "indexer_k_cache_scatter_op(Tensor k_fp8, Tensor k_scale, Tensor(a!) 
k_cache, " + "Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens) -> ()"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 9a7af4da49f..b526310564e 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -630,7 +630,8 @@ void attention(torch::Tensor q, std::optional k, std::optional cu_q_seqlens, std::optional cu_kv_seqlens, std::optional fmha_scheduler_counter, std::optional mla_bmm1_scale, std::optional mla_bmm2_scale, std::optional quant_q_buffer, - std::optional flash_mla_tile_scheduler_metadata, std::optional flash_mla_num_splits) + std::optional flash_mla_tile_scheduler_metadata, std::optional flash_mla_num_splits, + int64_t num_contexts, int64_t num_ctx_tokens) { TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx); // Use these tensors to infer if the attention is using KV cache @@ -833,20 +834,9 @@ void attention(torch::Tensor q, std::optional k, std::optional(num_contexts); int32_t const num_tokens = qkv_or_q.size(0); - int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item(); - int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens; + int32_t const num_gen_tokens = is_gen_only ? 
num_tokens : num_tokens - static_cast(num_ctx_tokens); auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item(); auto const gen_total_kv_len = host_total_kv_lens.index({1}).item(); diff --git a/cpp/tensorrt_llm/thop/attentionOp.h b/cpp/tensorrt_llm/thop/attentionOp.h index 0fc4788d6f0..cc2b3f787f0 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.h +++ b/cpp/tensorrt_llm/thop/attentionOp.h @@ -78,7 +78,8 @@ void attention(torch::Tensor q, std::optional k, std::optional fmha_scheduler_counter, std::optional mla_bmm1_scale, std::optional mla_bmm2_scale, std::optional quant_q_buffer, std::optional flash_mla_tile_scheduler_metadata = std::nullopt, - std::optional flash_mla_num_splits = std::nullopt); + std::optional flash_mla_num_splits = std::nullopt, int64_t num_contexts = 0, + int64_t num_ctx_tokens = 0); struct KvCachePoolPointers { diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 6d237cf01fe..1e63c3c111e 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1465,29 +1465,14 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor, self.layer_idx) num_tokens = k_fp8.shape[0] - head_dim = k_fp8.shape[1] - scale_size = k_scale.shape[1] * 4 # Convert to bytes (float32 = 4 bytes) - - # Convert to bytes: flatten first, then view as uint8, then reshape - k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view( - num_tokens, head_dim) - - # k_scale: for single-element tensors, contiguous() may be no-op - # Fix stride(-1) for byte-level view - k_scale_flat = k_scale.view(-1) - if k_scale_flat.stride(-1) != 1: - k_scale_flat = torch.as_strided(k_scale_flat.contiguous(), - size=(k_scale_flat.numel(), ), - stride=(1, )) - k_scale_bytes = k_scale_flat.view(torch.uint8).view( - num_tokens, scale_size) - - # Use CUDA kernel to scatter FP8 and scale bytes into cache - flat_indices_fp8 = 
metadata.slot_mapping_fp8[:num_tokens] - flat_indices_scale = metadata.slot_mapping_scale[:num_tokens] - torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, - k_cache, flat_indices_fp8, - flat_indices_scale) + + # The C++ op reinterprets k_fp8 (FP8) and k_scale (float32) as raw + # bytes internally and only reads the first num_tokens entries from + # the slot mapping buffers, avoiding Python-side view/slice overhead. + torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache, + metadata.slot_mapping_fp8, + metadata.slot_mapping_scale, + num_tokens) def sparse_attn_indexer( self, diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 7f1b7c0c4df..b1da0f66fe0 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -407,6 +407,8 @@ def run( mla_bmm1_scale: Optional[torch.Tensor] = None, mla_bmm2_scale: Optional[torch.Tensor] = None, quant_q_buffer: Optional[torch.Tensor] = None, + num_contexts: int = 0, + num_ctx_tokens: int = 0, ): """ Run the attention operation. 
@@ -638,6 +640,8 @@ def run( quant_q_buffer, self.quant_config, self.kv_cache_manager, + num_contexts, + num_ctx_tokens, global_layer_idx=self.global_layer_idx, ) else: @@ -722,6 +726,8 @@ def run( quant_q_buffer, self.flash_mla_tile_scheduler_metadata, self.flash_mla_num_splits, + num_contexts, + num_ctx_tokens, ) if self.print_skip_softmax_stat: @@ -2049,7 +2055,9 @@ def forward( fmha_scheduler_counter=fmha_scheduler_counter, mla_bmm1_scale=mla_bmm1_scale, mla_bmm2_scale=mla_bmm2_scale, - quant_q_buffer=quant_q_buffer) + quant_q_buffer=quant_q_buffer, + num_contexts=metadata.num_contexts, + num_ctx_tokens=metadata.num_ctx_tokens) if output_sf is None: return output diff --git a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py index 439831d4bf7..ff0ac200401 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py @@ -1437,23 +1437,6 @@ def run_mla_generation(self, params: EnqueueGenerationParams) -> None: params.context_buf.copy_(mla_out.reshape_as(params.context_buf)) -def _parse_request_types(host_request_types: torch.Tensor) -> Tuple[int, int]: - """ - Parse request types to count context and generation requests. - - Args: - host_request_types: Request types tensor (0=context, 1=generation). - num_seqs: Total number of sequences. - - Returns: - Tuple of (num_contexts, num_generations). 
- """ - - num_generations = host_request_types.sum().item() - num_contexts = host_request_types.size(0) - num_generations - return num_contexts, num_generations - - def is_supported( q: torch.Tensor, num_heads: int, @@ -1636,6 +1619,8 @@ def trtllm_gen_attention( quant_q_buffer: Optional[torch.Tensor], quant_config: Optional[QuantConfig], kv_cache_manager: Optional[KVCacheManager], + num_contexts: int, + num_ctx_tokens: int, global_layer_idx: Optional[int] = None, ) -> None: """ @@ -1766,20 +1751,9 @@ def trtllm_gen_attention( if attention_input_type is not None: attn_input_type = AttentionInputType(attention_input_type) - num_contexts, num_generations = _parse_request_types(host_request_types) - is_gen_only = attn_input_type == AttentionInputType.generation_only - is_ctx_only = attn_input_type == AttentionInputType.context_only - - if is_gen_only: - num_ctx_tokens = 0 - num_gen_tokens = num_tokens - elif is_ctx_only: - num_ctx_tokens = num_tokens - num_gen_tokens = 0 - else: - num_ctx_tokens = int(host_context_lengths[:num_contexts].sum()) if num_contexts > 0 else 0 - num_gen_tokens = num_tokens - num_ctx_tokens + num_generations = host_request_types.size(0) - num_contexts + num_gen_tokens = num_tokens - num_ctx_tokens # Prepare Workspace # Use upper-bound token counts for workspace sizing to avoid repeated diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py index f1c99267ed0..5d7895fd0f1 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py @@ -82,6 +82,9 @@ def __init__(self): # keeping a separate copy here since we sometimes have to overwrite the original values self.host_past_kv_lengths: Optional[torch.Tensor] = None # [max_batch] int32 pinned self.host_context_lengths: Optional[torch.Tensor] = None # [max_batch] int32 pinned + # Batch counts 
for thop.attention (updated every forward in plan_host) + self.num_contexts: int = 0 + self.num_ctx_tokens: int = 0 # Persistent block_offsets buffer for CUDA graph compatibility. # Pre-allocated to max size so the tensor address is stable across replays. self.block_offsets: Optional[torch.Tensor] = None @@ -171,6 +174,10 @@ def plan_host( """ num_seq = num_prefill + num_decode + # Batch counts for thop.attention + self.num_contexts = num_prefill + self.num_ctx_tokens = int(seq_len_host[:num_prefill].sum()) if num_prefill > 0 else 0 + # host_request_types: 0 = prefill (context), 1 = decode (generation) self.host_request_types[:num_prefill].fill_(0) self.host_request_types[num_prefill:num_seq].fill_(1) @@ -500,6 +507,8 @@ def trtllm_mha_with_cache( None, # mla_bmm1_scale None, # mla_bmm2_scale None, # quant_q_buffer + num_contexts=_GlobalTrtllmPlanner.num_contexts, + num_ctx_tokens=_GlobalTrtllmPlanner.num_ctx_tokens, ) if out is not None: diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 7c2ce12a7f4..a8e35cb77ad 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -2636,6 +2636,24 @@ def previous_seq_slots_device(): num_tokens = len(input_ids) num_draft_tokens = len(draft_tokens) total_num_tokens = len(position_ids) + if total_num_tokens > self.max_num_tokens: + ctx_details = [] + for r in scheduled_requests.context_requests: + pos = r.context_current_position + csz = r.context_chunk_size + full = len(r.get_tokens(0)) + tokens = min(csz, max(0, full - pos)) + ctx_details.append( + f"rid={r.py_request_id} pos={pos} chunk={csz} " + f"full={full} tokens={tokens}") + gen_count = len(scheduled_requests.generation_requests) + from tensorrt_llm.logger import logger as _mnt_logger + _mnt_logger.error( + f"MNT overflow: total={total_num_tokens} " + f"max={self.max_num_tokens} " + f"ctx_reqs={len(scheduled_requests.context_requests)} " + 
f"gen_reqs={gen_count} " + f"ctx_breakdown=[{'; '.join(ctx_details)}]") assert total_num_tokens <= self.max_num_tokens, ( f"total_num_tokens ({total_num_tokens}) should be less than or equal to max_num_tokens ({self.max_num_tokens})" ) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 34012dde254..b5d0aaa1085 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2823,6 +2823,116 @@ def _balance_adp_requests(self, context_requests: list[LlmRequest], balanced_context_requests = context_requests return balanced_context_requests + @staticmethod + def _compute_scheduled_tokens(context_requests, generation_requests): + """Compute the total number of scheduled tokens for batch waiting decisions. + + For context requests, we estimate the actual compute tokens for this + iteration (excluding tokens served from KV cache). + + For generation requests, each contributes 1 + num_draft_tokens. + + Note on reusable token handling: + estimated_reusable_tokens is an absolute count from position 0. + Depending on the scheduler, context_current_position may or may not + have been advanced past the reusable prefix by the time this method + is called: + - V1 scheduler: prepare_context runs after scheduling, so + context_current_position is still 0. + - V2 scheduler: prepare_context runs during scheduling, so + context_current_position is already advanced to the reused offset. + To handle both correctly, the reusable credit applied to the current + chunk is max(0, reusable - context_current_position), i.e. only the + portion of the reusable range that falls within this chunk's span. 
+ """ + num_scheduled_ctx_tokens = 0 + for ctx_req in context_requests: + reusable = (ctx_req.estimated_reusable_tokens + if ctx_req.is_first_context_chunk else 0) + # Credit only the reusable tokens that overlap with the current + # chunk: if context_current_position has already been advanced past + # the reusable prefix (V2), the credit is 0; if not (V1), the full + # reusable count is subtracted. + reusable_in_chunk = max(0, + reusable - ctx_req.context_current_position) + remaining = ctx_req.context_remaining_length + if (reusable_in_chunk > 0 and + reusable_in_chunk + ctx_req.context_chunk_size < remaining): + compute = ctx_req.context_chunk_size + else: + compute = max(1, ctx_req.context_chunk_size - reusable_in_chunk) + num_scheduled_ctx_tokens += compute + num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens + for gen_req in generation_requests) + return num_scheduled_ctx_tokens + num_scheduled_gen_tokens + + def _maybe_log_batch_wait_decision( + self, + context_requests: list[LlmRequest], + generation_requests: list[LlmRequest], + num_scheduled_tokens: int, + wait_threshold: float, + should_waiting: bool, + ) -> None: + """Diagnostics for batch_wait: set TLLM_LOG_BATCH_WAIT=1 (rank 0 only).""" + if self.dist.rank != 0: + return + + num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens + for gen_req in generation_requests) + num_scheduled_ctx_formula = num_scheduled_tokens - num_scheduled_gen_tokens + + chunk_ctx_sum = 0 + ctx_summaries: List[str] = [] + max_detail = 4 + for i, ctx_req in enumerate(context_requests): + full_len = len(ctx_req.get_tokens(0)) + begin = ctx_req.context_current_position + chunk_sz = ctx_req.context_chunk_size + this_chunk = min(chunk_sz, max(0, full_len - begin)) + chunk_ctx_sum += this_chunk + reusable = (ctx_req.estimated_reusable_tokens + if ctx_req.is_first_context_chunk else 0) + reusable_in_chunk = max(0, reusable - begin) + remaining = ctx_req.context_remaining_length + if (reusable_in_chunk > 0 + and 
reusable_in_chunk + chunk_sz < remaining): + formula_contrib = chunk_sz + else: + formula_contrib = max(1, chunk_sz - reusable_in_chunk) + if i < max_detail: + ctx_summaries.append( + f"rid={ctx_req.py_request_id} full={full_len} pos={begin} " + f"chunk_sz={chunk_sz} this_chunk={this_chunk} " + f"reusable={reusable} formula_contrib={formula_contrib}") + n_ctx = len(context_requests) + if n_ctx > max_detail: + ctx_summaries.append(f"... +{n_ctx - max_detail} more ctx req(s)") + + logger.info( + "batch_wait: formula_total=", + num_scheduled_tokens, + " formula_ctx=", + num_scheduled_ctx_formula, + " formula_gen=", + num_scheduled_gen_tokens, + " chunk_ctx_sum=", + chunk_ctx_sum, + " threshold=", + wait_threshold, + " wait_iter=", + self.batch_wait_iters_count, + "/", + self.batch_wait_timeout_iters, + " should_defer_ctx=", + should_waiting, + " num_gen=", + len(generation_requests), + " ctx_detail=[", + "; ".join(ctx_summaries), + "]", + ) + def _waiting_requests(self, context_requests: list[LlmRequest], generation_requests: list[LlmRequest]): """ @@ -2832,13 +2942,16 @@ def _waiting_requests(self, context_requests: list[LlmRequest], - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`. 
""" - num_scheduled_ctx_tokens = sum( - len(ctx_req.get_tokens(0)) for ctx_req in context_requests) - num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens - for gen_req in generation_requests) - num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens + num_scheduled_tokens = self._compute_scheduled_tokens( + context_requests, generation_requests) + wait_threshold = (self.batch_wait_max_tokens_ratio * + self.max_num_tokens) - should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens + should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < wait_threshold + self._maybe_log_batch_wait_decision(context_requests, + generation_requests, + num_scheduled_tokens, + wait_threshold, should_waiting) if should_waiting: self.batch_wait_iters_count += 1 return [] diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 3c0fa9e6011..6b368eabc39 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -619,6 +619,30 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): # wait for all pending work to finish before launching offload/onboarding/partial copy self.impl.sync_transfer_manager_with_buffer_manager() + # Pre-addSequence budget re-validation. The C++ scheduler + # should already account for the chunk-shift cost, but under + # heavy KV-cache eviction the actual reuse may be lower than + # estimated. We re-probe the radix tree and estimate the + # true forward cost; if it exceeds the remaining budget the + # request is skipped (re-scheduled next iteration). 
+ remaining_budget = None + if self.enable_block_reuse and not self.is_draft: + gen_tokens = sum( + req.get_beam_width_by_iter(for_next_iteration=False) + + get_draft_token_length(req) + for req in scheduled_batch.generation_requests) + remaining_budget = self.max_num_tokens - gen_tokens + + # Pre-subtract the fixed cost of non-first-chunk context + # requests. These have no reuse to re-validate and their + # compute cost is committed, so first-chunk budget checks + # must see the budget with these costs already removed. + for req in scheduled_batch.context_requests: + if not req.is_first_context_chunk: + remaining_budget -= req.context_chunk_size + + accepted_ctx_requests = [] + # allocate KV Cache for req in scheduled_batch.context_requests: req_beam_width = req.sampling_config.beam_width @@ -635,9 +659,37 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): else: if req.is_first_context_chunk and self._kv_connector_should_add_sequence( req): - self.impl.add_sequence(req.py_request_id, - req.prompt_len, req_beam_width, - req) + if remaining_budget is not None: + unique_tokens = req.get_unique_tokens(0) + reusable_blocks = self.impl.count_reusable_blocks( + unique_tokens, req, False) + actual_reuse = (reusable_blocks * + self.tokens_per_block) + req_compute = self._estimate_post_reuse_compute( + actual_reuse, req.context_chunk_size, + req.prompt_len) + if req_compute > remaining_budget: + logger.warning( + f"Reuse budget: skip req " + f"{req.py_request_id} " + f"(compute={req_compute}, " + f"chunk={req.context_chunk_size}, " + f"reuse={actual_reuse}, " + f"remaining={remaining_budget})") + continue + remaining_budget -= req_compute + + try: + self.impl.add_sequence(req.py_request_id, + req.prompt_len, + req_beam_width, req) + except RuntimeError: + logger.warning( + f"add_sequence: req " + f"{req.py_request_id} already exists, " + f"skipping") + accepted_ctx_requests.append(req) + continue for _ in range(self.num_extra_kv_tokens): 
                        self.impl.add_token(req.py_request_id)
                    for _ in range(get_draft_token_length(req)):
@@ -647,9 +699,20 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                block_ids = self.get_cache_indices(req)
                self.kv_connector_manager.update_state_after_alloc(
                    req, block_ids)
+            elif remaining_budget is not None and req.is_first_context_chunk:
+                # First-chunk request that skipped add_sequence
+                # (e.g. kv_connector said not to). Subtract its
+                # estimated cost so later first-chunk checks see a
+                # correct budget. Non-first-chunk costs were
+                # already pre-subtracted above.
+                reusable = req.estimated_reusable_tokens
+                remaining_budget -= self._estimate_post_reuse_compute(
+                    reusable, req.context_chunk_size, req.prompt_len)
+
+            accepted_ctx_requests.append(req)
         # A request may change from `context_requests_chunking` to `context_requests_last_chunk` in `add_sequence` due to KV cache reuse, so we rebuild the context request lists here.
-        scheduled_batch.reset_context_requests()
+        scheduled_batch.reset_context_requests(accepted_ctx_requests)
 
         for req in scheduled_batch.generation_requests:
             if self.mapping.has_cp_helix():
@@ -674,6 +737,23 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                 self.kv_connector_manager.build_scheduler_output(
                     scheduled_batch, self)
 
+    def _estimate_post_reuse_compute(self, reuse_tokens: int, chunk_size: int,
+                                     prompt_len: int) -> int:
+        """Estimate forward compute tokens after setPrepopulatedPromptLen.
+
+        For non-last chunks the chunk window shifts right by the reused
+        amount and the cost is the block-aligned chunk end minus the
+        reuse (roughly chunk_size). Last chunks cost max(1, prompt_len - reuse).
+ """ + P = reuse_tokens + if P > 0 and P < prompt_len: + if P + chunk_size < prompt_len: + aligned_end = ((P + chunk_size) // self.tokens_per_block * + self.tokens_per_block) + return max(1, aligned_end - P) + return max(1, prompt_len - P) + return chunk_size + def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool: return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence( request) diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py index b1054a27dd4..0bf2506885a 100644 --- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py +++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py @@ -716,7 +716,7 @@ def test_indexer_k_cache_scatter_custom_op(): dtype=torch.bfloat16) k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original) - # Prepare byte-level data + # Prepare byte-level data for the Python reference path scale_size = k_scale.shape[1] * 4 k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim) k_scale_flat = k_scale.view(-1) @@ -754,10 +754,11 @@ def test_indexer_k_cache_scatter_custom_op(): print(f" is_contiguous: {k_cache_python.is_contiguous()}") # ========== Path 1: CUDA Kernel ========== - print("\n=== Path 1: CUDA Kernel ===") - torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, - k_cache_cuda, flat_indices_fp8, - flat_indices_scale) + print(f"\n=== Path 1: CUDA Kernel ===") + torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache_cuda, + metadata.slot_mapping_fp8, + metadata.slot_mapping_scale, + num_tokens) torch.cuda.synchronize() print("✓ CUDA kernel completed")