Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,72 @@ class WindowBlockManager
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock);

//! \brief Per-request block allocation statistics from batch addSequence.
struct BatchSeqStats
{
SizeType32 prepopulatedLen{0};
SizeType32 allocTotalDelta{0};
SizeType32 allocNewDelta{0};
SizeType32 reusedDelta{0};
SizeType32 missedDelta{0};
};

//! \brief Result of Phase 1 (claim-only) of batch addSequence.
//! \details Holds matched blocks and prepared data so Phase 2 can proceed without
//! re-traversing the radix tree.
struct ClaimResult
{
struct ClaimedBlock
{
BlockPtr block;
SizeType32 numMatchedTokens; //!< tokens matched in this block
bool isPartialMatch;
bool needsCopy; //!< partial match on block with refs or non-leaf (needs getFreeBlock + copy in Phase 2)
bool isPlaceholder; //!< placeholder block (linear attention recurrent states)
};

std::vector<ClaimedBlock> claimedBlocks;
BlockPtr claimedCopySource; //!< unreferenced non-leaf partial-match source claimed to protect from eviction
SizeType32 totalMatchedTokens{0};
SizeType32 latestMatchingNonPlaceholderBlockIdx{-1};
SizeType32 numSharedContextBlocks{0};
SizeType32 numContextBlocks{0};
bool shareLastContextBlockAmongBeams{true};
std::vector<BlockKey> blockKeys;
std::vector<executor::RetentionPriorityAndDuration> perBlockRetentions;
executor::KvCacheTransferMode mode{executor::KvCacheTransferMode::DRAM};
std::string directory;
};

//! \brief Tracks which request currently "owns" a partially-matched leaf block across
//! the batch Phase 1 loop, so that at most one request reuses the block in-place
//! while all others copy.
struct PartialClaimTracker
{
struct Entry
{
size_t requestIdx; //!< index of the request that currently owns the reuse
size_t claimedIdx; //!< index into that request's claimedBlocks vector
bool fullyMatched; //!< true once any request fully matches this block
};

//! Keyed by block ID.
std::unordered_map<KVCacheBlock::IdType, Entry> map;
};

//! \brief Batch add sequences with two-phase claim-then-onboard under a single lock.
//! \details Phase 1 claims all matching blocks across all requests (protecting from eviction).
//! Phase 2 onboards host blocks and allocates non-matching blocks.
//! The mCachedBlocksRootMutex is held for the entire operation.
//! \param sequences Per-request GenerationRequest references (parallel with other vectors).
//! \param inputLengths Per-request effective input length.
//! \param numContextBlocksVec Per-request number of context blocks.
//! \param llmRequests Per-request LlmRequest references.
//! \return Per-request prepopulatedPromptLen.
[[nodiscard]] std::vector<BatchSeqStats> addSequenceBatch(std::vector<GenerationRequest*> const& sequences,
std::vector<SizeType32> const& inputLengths, std::vector<SizeType32> const& numContextBlocksVec,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests);

//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams);
Expand Down Expand Up @@ -1090,6 +1156,20 @@ class WindowBlockManager
bool shareLastContextBlockAmongBeams, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::string const& directory = "");

//! \brief Phase 1: Walk radix tree and claim matching blocks.
//! \details Caller must hold mCachedBlocksRootMutex.
//! Uses \p tracker to coordinate partial-match ownership across requests in
//! the same batch. \p claimResults is the full vector so that a previous
//! request's ClaimedBlock can be retroactively marked needsCopy.
[[nodiscard]] ClaimResult claimMatchingBlocks(GenerationRequest& sequence, SizeType32 inputLength,
SizeType32 numContextBlocks, LlmRequest& llmRequest, size_t requestIdx, PartialClaimTracker& tracker,
std::vector<ClaimResult>& claimResults);

//! \brief Phase 2 (lock-free): Onboard claimed host blocks and allocate non-matching blocks.
//! \details Caller must hold mCachedBlocksRootMutex.
[[nodiscard]] SizeType32 onboardAndAllocateBlocks(
GenerationRequest& sequence, LlmRequest& llmRequest, ClaimResult& claimResult);

//! \brief Free block and all it's descendants. This makes block a claimed leaf block.
void freeChildren(BlockPtr const& block);

Expand Down Expand Up @@ -1299,6 +1379,12 @@ class BlockManager
void addSequence(
GenerationRequest& sequence, SizeType32 numContextBlocks, SizeType32 windowSize, bool isShareLastContextBlock);

//! \brief Batch add sequences forwarding to WindowBlockManager::addSequenceBatch.
[[nodiscard]] std::vector<WindowBlockManager::BatchSeqStats> addSequenceBatch(
std::vector<GenerationRequest*> const& sequences, std::vector<SizeType32> const& inputLengths,
std::vector<SizeType32> const& numContextBlocksVec,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests, SizeType32 windowSize);

void allocateBlock(GenerationRequest& sequence, SizeType32 windowSize);

//! \brief According to request's current position, copy data from the last full block to the next block (ignoring
Expand Down Expand Up @@ -1806,6 +1892,15 @@ class BaseKVCacheManager
OptionalRef<LlmRequest> llmRequest = std::nullopt)
= 0;

//! \brief Batch add sequences with two-phase claim-then-onboard to prevent host offloading eviction.
//! \details Phase 1 claims all matching blocks across all requests (protecting them from eviction).
//! Phase 2 onboards host blocks and allocates non-matching blocks.
//! Requires block reuse enabled and single attention window.
virtual void addSequenceBatch(
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests)
= 0;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

[[nodiscard]] virtual std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false)
= 0;
Expand Down Expand Up @@ -2181,6 +2276,10 @@ class KVCacheManager : public BaseKVCacheManager
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
OptionalRef<LlmRequest> llmRequest = std::nullopt) override;

void addSequenceBatch(
std::vector<std::tuple<LlmRequest::RequestIdType, SizeType32, SizeType32>> const& requestInfos,
std::vector<std::reference_wrapper<LlmRequest>> const& llmRequests) override;

[[nodiscard]] std::optional<KVCacheBlock::IdType> removeSequence(LlmRequest::RequestIdType requestId,
OptionalRef<LlmRequest const> llmRequest = std::nullopt, bool pinOnRelease = false) override;

Expand Down
Loading
Loading