Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 43 additions & 32 deletions cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ namespace tensorrt_llm::batch_manager

using SizeType32 = MicroBatchScheduler::SizeType32;

/// Forward-pass token cost of a context chunk when KV-cache reuse is active.
///
/// With reuse, setPrepopulatedPromptLen shifts the chunk window right by the
/// prepopulated amount instead of shrinking it, so every chunk except the last
/// still costs its full @p chunkSize; only the final chunk is reduced to the
/// tokens that remain after the reusable prefix.
///
/// @param chunkSize        size of the chunk being scheduled
/// @param reusable         tokens served from the KV cache (first chunk only)
/// @param contextRemaining context tokens not yet processed for this request
/// @return number of tokens the model actually computes for this chunk
static SizeType32 reuse_adjusted_compute(SizeType32 chunkSize, SizeType32 reusable, SizeType32 contextRemaining)
{
    // No reusable prefix: the whole chunk is forward-pass work.
    if (reusable <= 0)
    {
        return chunkSize;
    }
    // Chunk window (shifted right by `reusable`) still ends before the context
    // does, i.e. this is not the last chunk: full chunk cost applies.
    bool const notLastChunk = reusable + chunkSize < contextRemaining;
    if (notLastChunk)
    {
        return chunkSize;
    }
    // Last chunk: only the tokens past the reusable prefix are computed.
    SizeType32 const tailCompute = contextRemaining - reusable;
    return tailCompute > 0 ? tailCompute : 0;
}

MicroBatchScheduler::MicroBatchScheduler(std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig,
std::optional<SizeType32> maxContextLength, LlmRequestState noScheduleUntilState,
LlmRequestState noScheduleAfterState)
Expand All @@ -38,16 +57,15 @@ void MicroBatchScheduler::fitDraftTokens(RequestVector& contextsToBeChunked,
std::optional<SizeType32> ctxTokensCapacity, SizeType32 const chunkUnitSize,
std::optional<SizeType32> const& maxContextLength)
{
// How many compute tokens (chunk - reusable) are in this batch already?
// How many compute tokens are in this batch already?
SizeType32 numCtxTokens{0};
for (auto const& llmReq : contextsToBeChunked)
{
SizeType32 const chunkSize = llmReq->getContextChunkSize();
// contextRemaining = P for first chunk; used to compute actual model token count.
SizeType32 const contextRemaining = llmReq->getContextRemainingLength();
SizeType32 const reusable
= llmReq->isFirstContextChunk() ? std::min(llmReq->getEstimatedReusableTokens(), contextRemaining) : 0;
numCtxTokens += std::min(chunkSize, std::max<SizeType32>(0, contextRemaining - reusable));
numCtxTokens += reuse_adjusted_compute(chunkSize, reusable, contextRemaining);
}

// Discard draft tokens that won't fit into the existing chunk unit, max
Expand Down Expand Up @@ -125,14 +143,13 @@ void MicroBatchScheduler::setCtxRequestsChunkSize<MicroBatchScheduler::ContextCh
SizeType32 actualChunkSize = llmReq->getContextChunkSize();
SizeType32 actualIncrement = actualChunkSize - pastChunkSize;

// Compute-aware budget: reusable tokens are served from cache and do not
// consume forward-pass capacity. Only the tokens beyond the reusable prefix count.
SizeType32 const reusable = llmReq->isFirstContextChunk()
? std::min(llmReq->getEstimatedReusableTokens(), llmReq->getContextRemainingLength())
: 0;
SizeType32 const pastCompute = std::max<SizeType32>(0, pastChunkSize - std::min(reusable, pastChunkSize));
SizeType32 const actualCompute
= std::max<SizeType32>(0, actualChunkSize - std::min(reusable, actualChunkSize));
// Compute-aware budget accounting for setPrepopulatedPromptLen's
// chunk-shift behaviour (non-last chunks keep their full size).
SizeType32 const contextRemaining = llmReq->getContextRemainingLength();
SizeType32 const reusable
= llmReq->isFirstContextChunk() ? std::min(llmReq->getEstimatedReusableTokens(), contextRemaining) : 0;
SizeType32 const pastCompute = reuse_adjusted_compute(pastChunkSize, reusable, contextRemaining);
SizeType32 const actualCompute = reuse_adjusted_compute(actualChunkSize, reusable, contextRemaining);
SizeType32 const computeIncrement = actualCompute - pastCompute;

if ((ctxTokensCapacity && numCtxTokens + computeIncrement > ctxTokensCapacity.value())
Expand Down Expand Up @@ -181,24 +198,21 @@ void MicroBatchScheduler::setCtxRequestsChunkSize<MicroBatchScheduler::ContextCh
for (auto& llmReq : contextsToBeChunked)
{
SizeType32 const suggestedChunkSize = llmReq->getContextRemainingLength();
// Reusable tokens are "free" — they don't consume forward-pass compute budget.
SizeType32 const reusable
= llmReq->isFirstContextChunk() ? std::min(llmReq->getEstimatedReusableTokens(), suggestedChunkSize) : 0;
SizeType32 const computeCost = suggestedChunkSize - reusable;
SizeType32 const computeCost = reuse_adjusted_compute(suggestedChunkSize, reusable, suggestedChunkSize);
SizeType32 actualChunkSize = suggestedChunkSize;
if (ctxTokensCapacity && computeCost > ctxTokensCapacity.value())
{
// Model processes min(chunk_size, P - reusable) tokens starting from position reusable.
// To keep model tokens within budget: chunk_size <= capacity (not reusable + capacity).
actualChunkSize = ctxTokensCapacity.value();
}
if (maxContextLength)
{
// maxContextLength limits compute tokens, not total tokens.
SizeType32 const actualCompute = std::max<SizeType32>(0, actualChunkSize - reusable);
SizeType32 const actualCompute = reuse_adjusted_compute(actualChunkSize, reusable, suggestedChunkSize);
if (actualCompute > maxContextLength.value())
{
actualChunkSize = std::min<SizeType32>(reusable + maxContextLength.value(), suggestedChunkSize);
actualChunkSize = maxContextLength.value();
actualChunkSize = std::min(actualChunkSize, suggestedChunkSize);
}
}
if (actualChunkSize != suggestedChunkSize)
Expand All @@ -208,10 +222,7 @@ void MicroBatchScheduler::setCtxRequestsChunkSize<MicroBatchScheduler::ContextCh
llmReq->setContextChunkSize(actualChunkSize);
if (ctxTokensCapacity)
{
// Decrement by actual model token count: min(chunk_size, P - reusable).
// This equals min(actualChunkSize, computeCost) since computeCost = suggestedChunkSize - reusable.
SizeType32 const modelCost
= std::min(actualChunkSize, std::max<SizeType32>(0, suggestedChunkSize - reusable));
SizeType32 const modelCost = reuse_adjusted_compute(actualChunkSize, reusable, suggestedChunkSize);
ctxTokensCapacity = ctxTokensCapacity.value() - modelCost;
}
}
Expand Down Expand Up @@ -349,10 +360,12 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
if (!mCtxChunkConfig) // skip chunking
{
constexpr SizeType32 beam{0};
reqNumTokens
= llmReq->getNumTokens(beam) + (llmReq->hasDraftTokens() ? llmReq->getNumDraftTokens() : 0);
// Compute tokens = total - reusable (at least 1 to make progress)
SizeType32 const computeTokens = std::max(1, reqNumTokens - reusable);
SizeType32 const contextTokens = llmReq->getNumTokens(beam);
SizeType32 const draftTokens = llmReq->hasDraftTokens() ? llmReq->getNumDraftTokens() : 0;
reqNumTokens = contextTokens + draftTokens;
SizeType32 const contextCompute
= reuse_adjusted_compute(contextTokens, reusable, llmReq->getContextRemainingLength());
SizeType32 const computeTokens = std::max(1, contextCompute + draftTokens);
TLLM_CHECK_WITH_INFO(!mMaxContextLength || computeTokens <= mMaxContextLength.value(),
"Context compute tokens (%d) exceeds the limit value (%d)", computeTokens,
mMaxContextLength.value());
Expand All @@ -369,9 +382,8 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
llmReq->setContextChunkSize(llmReq->getContextRemainingLength());
auto const draftTokens
= (llmReq->isLastContextChunk() && llmReq->hasDraftTokens()) ? llmReq->getNumDraftTokens() : 0;
// Compute cost: context compute + draft tokens
// (reusable tokens only offset context tokens, not draft tokens)
SizeType32 const contextCompute = std::max(0, llmReq->getContextChunkSize() - reusable);
SizeType32 const contextCompute = reuse_adjusted_compute(
llmReq->getContextChunkSize(), reusable, llmReq->getContextRemainingLength());
SizeType32 computeTokens = contextCompute + draftTokens;

if (mMaxContextLength)
Expand Down Expand Up @@ -444,10 +456,9 @@ std::tuple<RequestVector, RequestVector> MicroBatchScheduler::operator()(Request
if (llmReq->getContextChunkSize() > 0)
{
contextRequests.emplace_back(llmReq);
// Only count compute tokens (total - reusable).
// Reusable credit only applies to the first context chunk.
SizeType32 const reusable = llmReq->isFirstContextChunk() ? llmReq->getEstimatedReusableTokens() : 0;
SizeType32 const computeTokens = std::max(0, llmReq->getContextChunkSize() - reusable);
SizeType32 const computeTokens
= reuse_adjusted_compute(llmReq->getContextChunkSize(), reusable, llmReq->getContextRemainingLength());
batchNumTokens += computeTokens;
TLLM_LOG_DEBUG("context request scheduled: ID %lu, chunk size %d%s", llmReq->mRequestId,
llmReq->getContextChunkSize(), reusable > 0 ? (", reusable " + std::to_string(reusable)).c_str() : "");
Expand Down
4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ void initBindings(nb::module_& m)
nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
nb::arg("quant_q_buffer") = std::nullopt, nb::arg("flash_mla_tile_scheduler_metadata") = std::nullopt,
nb::arg("flash_mla_num_splits") = std::nullopt, "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
nb::arg("flash_mla_num_splits") = std::nullopt, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0,
"Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());

m.def(
"get_helix_workspace_size_per_rank",
Expand Down
101 changes: 49 additions & 52 deletions cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,69 +28,66 @@ TRTLLM_NAMESPACE_BEGIN
namespace torch_ext
{

void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
void indexer_k_cache_scatter_op(th::Tensor const& k_fp8, th::Tensor const& k_scale, th::Tensor& k_cache,
th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale, int64_t num_tokens)
{
// Validate all tensors are CUDA tensors
TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
// k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8
// k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes
// slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used
// k_cache: [num_blocks, block_size, 1, per_token_size] uint8

TORCH_CHECK(k_fp8.is_cuda() && k_scale.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
&& slot_mapping_scale.is_cuda(),
"All tensors must be CUDA tensors");

// Validate tensor dimensions
TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");

// Enforce k_cache is 4D tensor
TORCH_CHECK(k_cache.dim() == 4,
"k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions",
TORCH_CHECK(k_fp8.dim() == 2, "k_fp8 must be 2D [num_tokens, head_dim]");
TORCH_CHECK(k_scale.dim() == 2, "k_scale must be 2D [num_tokens, scale_elements]");
TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be 1D [num_tokens]");
TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be 1D [num_tokens]");
TORCH_CHECK(k_cache.dim() == 4, "k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims",
static_cast<int>(k_cache.dim()));

// Validate tensor dtypes
TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
// Validate tensor dtypes — reinterpret_cast below assumes specific element sizes
TORCH_CHECK(k_fp8.element_size() == 1, "k_fp8 must have 1-byte elements (e.g. FP8), got %d", k_fp8.element_size());
TORCH_CHECK(k_scale.element_size() == 4, "k_scale must have 4-byte elements (e.g. float32), got %d",
k_scale.element_size());
TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");

// Validate tensor shapes are consistent
auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
TORCH_CHECK(
k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");

// Validate tensors are contiguous (except k_cache which may be non-contiguous)
TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
// k_cache can be non-contiguous - we handle this via strides
TORCH_CHECK(k_fp8.is_contiguous(), "k_fp8 must be contiguous");
TORCH_CHECK(k_scale.is_contiguous(), "k_scale must be contiguous");
TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");

int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128
int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes

int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size

// Validation for indexer k cache pool for DeepSeek-V3.2 constraints
TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2);
TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim);
TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size);

int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));

auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());

tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
cache_stride_1, cache_stride_2, cache_stride_3, stream);
// FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes.
int32_t const head_dim = static_cast<int32_t>(k_fp8.size(1));
// Scale size in bytes: num_scale_elements * bytes_per_element.
int32_t const scale_size = static_cast<int32_t>(k_scale.size(1)) * static_cast<int32_t>(k_scale.element_size());

int32_t const cache_dim_0 = static_cast<int32_t>(k_cache.size(0));
int32_t const cache_dim_1 = static_cast<int32_t>(k_cache.size(1));
int32_t const cache_dim_2 = static_cast<int32_t>(k_cache.size(2));
int32_t const cache_dim_3 = static_cast<int32_t>(k_cache.size(3));

TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1, got %d", cache_dim_2);
TORCH_CHECK(head_dim == 128, "k_fp8 head_dim must be 128, got %d", head_dim);
TORCH_CHECK(scale_size == 4, "k_scale scale_size must be 4 bytes, got %d", scale_size);

int64_t const cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
int64_t const cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
int64_t const cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
int64_t const cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));

auto stream = at::cuda::getCurrentCUDAStream(k_fp8.get_device());

// Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr.
// For slot mappings, use data_ptr directly — only the first num_tokens entries are read.
tk::invokeIndexerKCacheScatter(reinterpret_cast<uint8_t const*>(k_fp8.data_ptr()),
reinterpret_cast<uint8_t const*>(k_scale.data_ptr()), k_cache.data_ptr<uint8_t>(),
slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(), static_cast<int32_t>(num_tokens),
head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1,
cache_stride_2, cache_stride_3, stream);
}

} // namespace torch_ext
Expand All @@ -100,8 +97,8 @@ TRTLLM_NAMESPACE_END
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    // Schema must match indexer_k_cache_scatter_op above: six arguments, with
    // k_cache mutated in place (Tensor(a!)) and an explicit num_tokens count.
    m.def(
        "indexer_k_cache_scatter_op(Tensor k_fp8, Tensor k_scale, Tensor(a!) k_cache, "
        "Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens) -> ()");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
Expand Down
Loading
Loading