From 98c23e71c817b34a0390d393859c62b0b335c15f Mon Sep 17 00:00:00 2001 From: Kan Date: Thu, 18 Dec 2025 02:10:23 -0800 Subject: [PATCH 1/3] add flashinfer enum --- examples/inference/gpt/gpt_dynamic_inference.py | 11 ++++++++++- megatron/core/inference/contexts/dynamic_context.py | 4 ++++ megatron/core/transformer/enums.py | 3 +++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index 6c2a539ce7..e8a1aa7bf0 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -31,6 +31,7 @@ ContextOverflowError, DynamicInferenceContext, ) +from megatron.core.transformer.enums import AttnBackend from megatron.core.inference.contexts.attention_context.mamba_metadata import ( MambaInferenceStateConfig, ) @@ -158,6 +159,13 @@ def get_inference_context( if args.inference_logging_step_interval > 0 and args.inference_wandb_logging: metrics_writer = get_wandb_writer() + # Use smaller block size for flashinfer backends + block_size = ( + 16 + if hasattr(args, 'attention_backend') and args.attention_backend in [AttnBackend.flashinfer_fa2, AttnBackend.flashinfer_fa3, AttnBackend.flashinfer_trt] + else args.inference_dynamic_batching_block_size + ) + # Inference context. context = DynamicInferenceContext( params_dtype=args.params_dtype, @@ -172,7 +180,7 @@ def get_inference_context( if args.cuda_graph_impl == "local" else None ), - block_size_tokens=args.inference_dynamic_batching_block_size, + block_size_tokens=block_size, buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb, max_requests=args.inference_dynamic_batching_max_requests, max_tokens=args.inference_dynamic_batching_max_tokens, @@ -184,6 +192,7 @@ def get_inference_context( qk_pos_emb_head_dim=args.qk_pos_emb_head_dim, use_cuda_graphs_for_non_decode_steps=not args.decode_only_cuda_graphs, use_flashinfer_fused_rope=args.use_flashinfer_fused_rope, + attention_backend=getattr(args, 'attention_backend', AttnBackend.flash), unified_memory_level=args.inference_dynamic_batching_unified_memory_level, cuda_graph_max_tokens=args.inference_dynamic_batching_cuda_graph_max_tokens, cuda_graph_mixed_prefill_count=args.inference_dynamic_batching_cuda_graph_mixed_prefill_count, diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 6e70d71fe2..d2277c6fcb 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -12,6 +12,7 @@ from torch import Tensor from megatron.core import parallel_state +from megatron.core.transformer.enums import AttnBackend from megatron.core.inference.batch_dimensions_utils import ( CUDAGraphBatchDimensionBuilder, InferenceBatchDimensions, @@ -240,6 +241,7 @@ class DynamicInferenceContext(BaseInferenceContext): levels will be included to control other tensors within the context. use_flashinfer_fused_rope (bool): If True, use flashinfer's fused rope implementation. If None, defaults to using flash-infer if available. + attention_backend (AttnBackend): Attention backend to use. Defaults to AttnBackend.flash. metrics_writer (Optional['WandbModule']): Wandb module for writing metrics. request_metadata_types (Optional[List[Tuple[str, torch.dtype, bool]]]): A list of the per-request metadata types to track. 
Each entry is a tuple consisting of the string @@ -271,6 +273,7 @@ def __init__( mamba_inference_state_config: Optional[MambaInferenceStateConfig] = None, use_cuda_graphs_for_non_decode_steps: bool = True, use_flashinfer_fused_rope: bool = False, + attention_backend: AttnBackend = AttnBackend.flash, unified_memory_level: Optional[int] = 1, cuda_graph_max_tokens: Optional[int] = None, cuda_graph_mixed_prefill_count: Optional[int] = 16, @@ -279,6 +282,7 @@ def __init__( ): super().__init__(materialize_only_last_token_logits=materialize_only_last_token_logits) + self.attention_backend = attention_backend self.cache_mla_latent = cache_mla_latent if self.cache_mla_latent: assert ( diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py index 52b82029f9..300d14d367 100644 --- a/megatron/core/transformer/enums.py +++ b/megatron/core/transformer/enums.py @@ -65,3 +65,6 @@ class AttnBackend(enum.Enum): unfused = 3 local = 4 auto = 5 + flashinfer_fa2 = 6 + flashinfer_fa3 = 7 + flashinfer_trt = 8 From ec96bc132467cd4a303de3bebd82a2ecdd50aa8e Mon Sep 17 00:00:00 2001 From: Kan Date: Thu, 18 Dec 2025 03:32:09 -0800 Subject: [PATCH 2/3] add flashinfer enum --- .../inference/contexts/dynamic_context.py | 135 +++---- .../contexts/fused_kv_append_kernel.py | 379 ++++++++++++++---- megatron/core/inference/kv_cache.py | 304 ++++++++++++++ .../contexts/test_fused_kv_append.py | 308 ++++++++++++++ 4 files changed, 968 insertions(+), 158 deletions(-) create mode 100644 megatron/core/inference/kv_cache.py create mode 100644 tests/unit_tests/inference/contexts/test_fused_kv_append.py diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index d2277c6fcb..f66d6f2d4c 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -38,6 +38,7 @@ from .attention_context.mha_metadata import GraphedMHAMetadata, NonGraphedMHAMetadata from .base_context import BaseInferenceContext from .dynamic_block_allocator import BlockAllocator +from ..kv_cache import KVCacheBase, KVCacheLayout, MLACache, create_mhagqa_cache try: from .fused_kv_append_kernel import triton_append_key_value_cache @@ -585,34 +586,44 @@ def allocate_all_tensors(self, *, is_init: bool) -> None: self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids) self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids) - # Memory buffer. + # Determine cache layout based on attention backend. + if self.cache_mla_latent: + self._cache_layout = None # MLA uses its own layout + elif self.attention_backend in [ + AttnBackend.flashinfer_fa2, + AttnBackend.flashinfer_fa3, + AttnBackend.flashinfer_trt, + ]: + self._cache_layout = KVCacheLayout.M_N2HCD # FlashInfer layout + else: + self._cache_layout = KVCacheLayout.M_2NCHD # Flash backend (default) + + # Memory buffer - list of cache objects, one per attention layer. + # Indexed via layer_map[global_layer_idx] to get attention layer index. def allocate_memory_buffer(): """Allocate the memory buffer. 
This function is called below within `with ctx_manager:`.""" - if self.cache_mla_latent: - self.memory_buffer = torch.empty( - ( - self.num_attention_layers, - self.block_allocator.total_count, - self.block_size_tokens, - self.kv_reduced_dim, - ), - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - else: - self.memory_buffer = torch.empty( - ( - 2, # key and value - self.num_attention_layers, - self.block_allocator.total_count, - self.block_size_tokens, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ), - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) + self.memory_buffer: List[KVCacheBase] = [] + for _ in range(self.num_attention_layers): + if self.cache_mla_latent: + cache = MLACache( + num_chunks=self.block_allocator.total_count, + chunk_size=self.block_size_tokens, + kv_reduced_dim=self.kv_reduced_dim, + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + else: + cache = create_mhagqa_cache( + layout=self._cache_layout, + num_chunks=self.block_allocator.total_count, + chunk_size=self.block_size_tokens, + num_kv_heads=self.num_attention_heads_per_partition, + head_dim=self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + self.memory_buffer.append(cache) # Optional state tensors for hybrid models def allocate_mamba_states(): @@ -810,75 +821,59 @@ def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) """Append to KV cache. Args: - layer_number (int): Layer number. + layer_number (int): Layer number (1-based). key (Tensor): Key tensor. value (Tensor): Value tensor. """ attention_layer_number = self.layer_map[layer_number - 1] + cache = self.memory_buffer[attention_layer_number] - if triton_append_key_value_cache is not None and not self.cache_mla_latent: - # currently does not support MLA latent cache + # Use Triton kernel if cache supports it + if triton_append_key_value_cache is not None and cache.supports_triton(): return triton_append_key_value_cache( - layer_number=attention_layer_number, key=key, value=value, - memory_buffer=self.memory_buffer, + cache=cache, padded_active_token_count=self.padded_active_token_count, token_to_block_idx=self.token_to_block_idx, token_to_local_position_within_kv_block=self.token_to_local_position_within_kv_block, ) - block_idx = self.token_to_block_idx[: self.padded_active_token_count] - local_kv_seq_idx = self.token_to_local_position_within_kv_block[ - : self.padded_active_token_count - ] - - if not self.cache_mla_latent: - assert key.size(1) == 1 and value.size(1) == 1 - - key = key.squeeze(1) - # There is no value cache in FlashMLA/absorption - if not self.cache_mla_latent: - value = value.squeeze(1) - - if self.cache_mla_latent: - # We pass the kv_concat as the key in cache_mla_latent - kv_concat = key - self.memory_buffer[attention_layer_number, block_idx, local_kv_seq_idx] = kv_concat[ - : self.padded_active_token_count - ] - else: - self.memory_buffer[0, attention_layer_number, block_idx, local_kv_seq_idx] = key[ - : self.padded_active_token_count - ] - self.memory_buffer[1, attention_layer_number, block_idx, local_kv_seq_idx] = value[ - : self.padded_active_token_count - ] + # Fallback: use cache's append method + cache.append( + key=key, + value=value, + padded_active_token_count=self.padded_active_token_count, + token_to_block_idx=self.token_to_block_idx, + token_to_local_position_within_kv_block=self.token_to_local_position_within_kv_block, + ) - def key_value_cache(self, 
layer_number: int) -> Tuple[Tensor, Tensor]: + def key_value_cache(self, layer_number: int) -> Tuple[Tensor, Optional[Tensor], Tensor]: """Read from KV cache. Args: - layer_number (int): Layer number. + layer_number (int): Layer number (1-based). Return: - (Tuple[Tensor, Tensor]) The key and value pointer tensors that point - to blocks within the block-level memory buffer. + (Tuple[Tensor, Optional[Tensor], Tensor]) The key cache, value cache (or None for MLA), + and block table tensor. """ attention_layer_number = self.layer_map[layer_number - 1] + cache = self.memory_buffer[attention_layer_number] + cache_content = cache.get_content() + block_table = self.active_attn_metadata["mha_metadata"].state_data["block_table"] if self.cache_mla_latent: - return ( - self.memory_buffer[attention_layer_number], - None, - self.active_attn_metadata["mha_metadata"].state_data["block_table"], - ) + # MLA: cache_content is single tensor [N, C, D] + return (cache_content, None, block_table) + elif self._cache_layout == KVCacheLayout.M_2NCHD: + # M_2NCHD: [2, N, C, H, D] - slice on dim 0 + return (cache_content[0], cache_content[1], block_table) + elif self._cache_layout == KVCacheLayout.M_N2HCD: + # M_N2HCD: [N, 2, H, C, D] - slice on dim 1 + return (cache_content[:, 0], cache_content[:, 1], block_table) else: - return ( - self.memory_buffer[0, attention_layer_number], - self.memory_buffer[1, attention_layer_number], - self.active_attn_metadata["mha_metadata"].state_data["block_table"], - ) + raise ValueError(f"Unknown cache layout: {self._cache_layout}") def mamba_states_cache(self, layer_number: int) -> Tuple[Tensor, Tensor]: """Returns the Mamba state tensors for the given layer.""" diff --git a/megatron/core/inference/contexts/fused_kv_append_kernel.py b/megatron/core/inference/contexts/fused_kv_append_kernel.py index db1eed456e..42b3a59e44 100644 --- a/megatron/core/inference/contexts/fused_kv_append_kernel.py +++ b/megatron/core/inference/contexts/fused_kv_append_kernel.py @@ -1,17 +1,23 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from typing import Optional + import triton import triton.language as tl from torch import Tensor +# ============================================================================ +# TRITON KERNELS +# ============================================================================ + + @triton.jit -def _append_kv_cache_kernel( +def _append_kv_merged_kernel( # --- Pointers to Tensors --- key_ptr, value_ptr, - key_cache_ptr, - value_cache_ptr, + cache_ptr, block_idx_ptr, local_kv_seq_idx_ptr, # --- Strides for Tensor Memory Layout --- @@ -21,6 +27,7 @@ def _append_kv_cache_kernel( stride_value_token, stride_value_head, stride_value_hdim, + stride_cache_kv, stride_cache_block, stride_cache_pos, stride_cache_head, @@ -33,17 +40,11 @@ def _append_kv_cache_kernel( BLOCK_SIZE_H: tl.constexpr, ): """ - Triton kernel to append key and value vectors to pre-sliced paged KV cache tensors. + Triton kernel for merged KV cache layouts (M_2NCHD, M_N2HCD, M_N2CHD). + Handles caches where K and V are stored in a single tensor with a KV dimension. Each program instance handles one head of one token. The grid is 2D: (n_tokens, num_heads). - - 1. It identifies which token and head it is responsible for using `tl.program_id`. - 2. It loads the `block_idx` and `local_pos` for that token. - 3. It loads the `h_dim` vector for its assigned key/value head. - 4. It calculates the destination address in the 4D cache slices. - 5. 
It writes (scatters) the head vector to its destination in the cache. """ - token_idx = tl.program_id(0) head_idx = tl.program_id(1) @@ -64,111 +65,313 @@ def _append_kv_cache_kernel( key_to_write = tl.load(key_head_ptr + offs_h * stride_key_hdim, mask=mask_h, other=0.0) value_to_write = tl.load(value_head_ptr + offs_h * stride_value_hdim, mask=mask_h, other=0.0) - # --- Calculate destination pointers in the 4D KV cache slices --- - dest_offset = ( - block_idx * stride_cache_block + local_pos * stride_cache_pos + head_idx * stride_cache_head + # --- Calculate destination pointers in the merged cache --- + # The stride_cache_kv allows us to select between K (0) and V (1) + base_offset = ( + block_idx * stride_cache_block + + local_pos * stride_cache_pos + + head_idx * stride_cache_head ) - key_dest_ptr = key_cache_ptr + dest_offset - value_dest_ptr = value_cache_ptr + dest_offset + key_dest_ptr = cache_ptr + 0 * stride_cache_kv + base_offset + value_dest_ptr = cache_ptr + 1 * stride_cache_kv + base_offset # --- Store the head data into the cache --- tl.store(key_dest_ptr + offs_h * stride_cache_hdim, key_to_write, mask=mask_h) tl.store(value_dest_ptr + offs_h * stride_cache_hdim, value_to_write, mask=mask_h) -def triton_append_key_value_cache( - layer_number: int, +@triton.jit +def _append_mla_kernel( + # --- Pointers to Tensors --- + kv_concat_ptr, + cache_ptr, + block_idx_ptr, + local_kv_seq_idx_ptr, + # --- Strides for Tensor Memory Layout --- + stride_kv_token, + stride_kv_dim, + stride_cache_block, + stride_cache_pos, + stride_cache_dim, + # --- Other Parameters --- + n_tokens: tl.int32, + LATENT_DIM: tl.int32, + # --- Compile-Time Constants --- + BLOCK_SIZE_D: tl.constexpr, +): + """ + Triton kernel for MLA cache layout. + Handles compressed latent representation (no K/V split, no heads). + + Each program instance handles one token. The grid is 1D: (n_tokens,). + """ + token_idx = tl.program_id(0) + + if token_idx >= n_tokens: + return + + # --- Load destination indices for the current token --- + block_idx = tl.load(block_idx_ptr + token_idx) + local_pos = tl.load(local_kv_seq_idx_ptr + token_idx) + + # --- Load the latent representation for the current token --- + offs_d = tl.arange(0, BLOCK_SIZE_D) + mask_d = offs_d < LATENT_DIM + + kv_ptr = kv_concat_ptr + token_idx * stride_kv_token + kv_to_write = tl.load(kv_ptr + offs_d * stride_kv_dim, mask=mask_d, other=0.0) + + # --- Calculate destination pointer in the MLA cache --- + dest_offset = block_idx * stride_cache_block + local_pos * stride_cache_pos + + dest_ptr = cache_ptr + dest_offset + + # --- Store the latent data into the cache --- + tl.store(dest_ptr + offs_d * stride_cache_dim, kv_to_write, mask=mask_d) + + +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ + + +def _validate_and_prepare_tensors( + key: Tensor, value: Optional[Tensor], n_tokens: int +) -> Tuple[Tensor, Optional[Tensor], int, int, int]: + """ + Validate input tensors and extract common dimensions. + + Args: + key: Key tensor of shape (batch_size, 1, num_heads, h_dim) or (batch_size, 1, latent_dim) + value: Value tensor of shape (batch_size, 1, num_heads, h_dim) or None for MLA + n_tokens: Number of tokens to process + + Returns: + Tuple of (squeezed_key, squeezed_value, num_heads, h_dim, n_tokens) + """ + assert key.device.type == 'cuda', "All tensors must be on CUDA devices." 
+ if value is not None: + assert value.device.type == 'cuda', "All tensors must be on CUDA devices." + + assert key.size(1) == 1, "Key should have a sequence length of 1." + key = key.squeeze(1) + + if value is not None: + assert value.size(1) == 1, "Value should have a sequence length of 1." + value = value.squeeze(1) + + if n_tokens == 0: + return key, value, 0, 0, 0 + + # Extract dimensions + if key.dim() == 3: # [batch, heads, dim] + _, num_heads, h_dim = key.shape + elif key.dim() == 2: # [batch, dim] for MLA + num_heads = 0 + h_dim = key.size(-1) + else: + raise ValueError(f"Unexpected key shape: {key.shape}") + + return key, value, num_heads, h_dim, n_tokens + + +# ============================================================================ +# GENERIC WRAPPERS +# ============================================================================ + + +def _append_merged_cache( key: Tensor, value: Tensor, - memory_buffer: Tensor, + cache: Tensor, + n_tokens: int, + num_heads: int, + h_dim: int, + block_idx_active: Tensor, + local_kv_seq_idx_active: Tensor, + kv_dim_idx: int, + block_dim_idx: int, + pos_dim_idx: int, + head_dim_idx: int, +) -> None: + """ + Generic wrapper for merged cache layouts. + + Args: + kv_dim_idx: Index of the KV dimension (2) in the cache shape + block_dim_idx: Index of the block/chunk dimension (N) in the cache shape + pos_dim_idx: Index of the position dimension (C) in the cache shape + head_dim_idx: Index of the head dimension (H) in the cache shape + """ + grid = (n_tokens, num_heads) + BLOCK_SIZE_H = triton.next_power_of_2(h_dim) + + cache_strides = cache.stride() + stride_cache_kv = cache_strides[kv_dim_idx] + stride_cache_block = cache_strides[block_dim_idx] + stride_cache_pos = cache_strides[pos_dim_idx] + stride_cache_head = cache_strides[head_dim_idx] + stride_cache_hdim = cache_strides[-1] # Last dimension is always head_dim + + _append_kv_merged_kernel[grid]( + key, + value, + cache, + block_idx_active, + local_kv_seq_idx_active, + key.stride(0), + key.stride(1), + key.stride(2), + value.stride(0), + value.stride(1), + value.stride(2), + stride_cache_kv, + stride_cache_block, + stride_cache_pos, + stride_cache_head, + stride_cache_hdim, + n_tokens=n_tokens, + num_heads=num_heads, + H_DIM=h_dim, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) + + +def _append_mla_cache( + kv_concat: Tensor, + cache: Tensor, + n_tokens: int, + latent_dim: int, + block_idx_active: Tensor, + local_kv_seq_idx_active: Tensor, +) -> None: + """ + Wrapper for MLA cache layout: [N, C, D] + """ + grid = (n_tokens,) + BLOCK_SIZE_D = triton.next_power_of_2(latent_dim) + + cache_strides = cache.stride() + + _append_mla_kernel[grid]( + kv_concat, + cache, + block_idx_active, + local_kv_seq_idx_active, + kv_concat.stride(0), + kv_concat.stride(1), + cache_strides[0], # stride_cache_block + cache_strides[1], # stride_cache_pos + cache_strides[2], # stride_cache_dim + n_tokens=n_tokens, + LATENT_DIM=latent_dim, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + + +# ============================================================================ +# MAIN DISPATCHER +# ============================================================================ + + +def triton_append_key_value_cache( + key: Tensor, + value: Optional[Tensor], + cache, # KVCacheBase instance padded_active_token_count: int, token_to_block_idx: Tensor, token_to_local_position_within_kv_block: Tensor, ) -> None: """ - Append to KV cache using a high-performance, standalone Triton kernel. + Append to KV cache using high-performance Triton kernels. 
+ + This function supports the following cache layouts: + - M_2NCHD: Merged [2, N, C, H, D] + - M_N2HCD: Merged [N, 2, H, C, D] + - MLA: [N, C, D] Args: - layer_number (int): Layer number (1-based). - key (Tensor): Key tensor of shape (batch_size, 1, num_heads, h_dim). - value (Tensor): Value tensor of shape (batch_size, 1, num_heads, h_dim). - memory_buffer (Tensor): The 6D KV cache tensor to write to. + key (Tensor): Key tensor of shape (batch_size, 1, num_heads, h_dim) or + (batch_size, 1, latent_dim) for MLA. + value (Optional[Tensor]): Value tensor of shape (batch_size, 1, num_heads, h_dim) + or None for MLA. + cache: KVCacheBase instance (must support Triton). padded_active_token_count (int): The number of active tokens to process. token_to_block_idx (Tensor): Tensor mapping token index to its block index in the cache. token_to_local_position_within_kv_block (Tensor): Tensor mapping token index to its position within a block. """ + # Import cache classes (avoid circular imports by importing locally) + from megatron.core.inference.kv_cache import ( + KVCacheM2NCHD, + KVCacheMN2HCD, + MLACache, + ) + # --- Input Validation and Preparation --- - assert ( - key.device.type == 'cuda' - and value.device.type == 'cuda' - and memory_buffer.device.type == 'cuda' - ), "All tensors must be on CUDA devices." - - assert ( - key.size(1) == 1 and value.size(1) == 1 - ), "Key and Value should have a sequence length of 1." - key = key.squeeze(1) - value = value.squeeze(1) + key, value, num_heads, h_dim, n_tokens = _validate_and_prepare_tensors( + key, value, padded_active_token_count + ) - n_tokens = padded_active_token_count if n_tokens == 0: return - _, num_heads, h_dim = key.shape - - key_cache = memory_buffer[0, layer_number] - value_cache = memory_buffer[1, layer_number] - + # Get active slices key_to_cache = key[:n_tokens] - value_to_cache = value[:n_tokens] - block_idx_active = token_to_block_idx[:n_tokens] - local_kv_seq_idx_active = token_to_local_position_within_kv_block[:n_tokens] - - assert ( - key_cache.dim() == 4 and value_cache.dim() == 4 - ), f"Sliced key_cache and value_cache should be 4D" - assert ( - num_heads == key_cache.shape[-2] - ), f"Head count mismatch. Key/Value has {num_heads} but cache expects {key_cache.shape[-2]}." - assert ( - h_dim == key_cache.shape[-1] - ), f"Head dimension mismatch. Key/Value has {h_dim} but cache expects {key_cache.shape[-1]}." 
- - block_idx_active = block_idx_active.contiguous() - local_kv_seq_idx_active = local_kv_seq_idx_active.contiguous() - - grid = (n_tokens, num_heads) - BLOCK_SIZE_H = triton.next_power_of_2(h_dim) + value_to_cache = value[:n_tokens] if value is not None else None + block_idx_active = token_to_block_idx[:n_tokens].contiguous() + local_kv_seq_idx_active = token_to_local_position_within_kv_block[:n_tokens].contiguous() - cache_strides = key_cache.stride() + # Get cache tensors from the cache object + cache_content = cache.get_content() - _append_kv_cache_kernel[grid]( - # Pointers - key_to_cache, - value_to_cache, - key_cache, - value_cache, - block_idx_active, - local_kv_seq_idx_active, - # Strides for 3D key/value tensors - key_to_cache.stride(0), - key_to_cache.stride(1), - key_to_cache.stride(2), - value_to_cache.stride(0), - value_to_cache.stride(1), - value_to_cache.stride(2), - # Strides for the 4D sliced cache - cache_strides[0], - cache_strides[1], - cache_strides[2], - cache_strides[3], - # Other parameters - n_tokens=n_tokens, - num_heads=num_heads, - H_DIM=h_dim, - # Compile-time constant - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) + # Dispatch based on cache type + if isinstance(cache, MLACache): + # MLA cache: [N, C, D] + _append_mla_cache( + kv_concat=key_to_cache, + cache=cache_content, + n_tokens=n_tokens, + latent_dim=h_dim, + block_idx_active=block_idx_active, + local_kv_seq_idx_active=local_kv_seq_idx_active, + ) + elif isinstance(cache, KVCacheM2NCHD): + # M_2NCHD: [2, N, C, H, D] - KV at dim 0, block at 1, pos at 2, head at 3 + _append_merged_cache( + key=key_to_cache, + value=value_to_cache, + cache=cache_content, + n_tokens=n_tokens, + num_heads=num_heads, + h_dim=h_dim, + block_idx_active=block_idx_active, + local_kv_seq_idx_active=local_kv_seq_idx_active, + kv_dim_idx=0, + block_dim_idx=1, + pos_dim_idx=2, + head_dim_idx=3, + ) + elif isinstance(cache, KVCacheMN2HCD): + # M_N2HCD: [N, 2, H, C, D] - block at 0, KV at 1, head at 2, pos at 3 + _append_merged_cache( + key=key_to_cache, + value=value_to_cache, + cache=cache_content, + n_tokens=n_tokens, + num_heads=num_heads, + h_dim=h_dim, + block_idx_active=block_idx_active, + local_kv_seq_idx_active=local_kv_seq_idx_active, + kv_dim_idx=1, + block_dim_idx=0, + pos_dim_idx=3, + head_dim_idx=2, + ) + else: + raise TypeError( + f"Unsupported cache type: {type(cache).__name__}. " + f"Triton kernel only supports M_2NCHD, M_N2HCD, and MLA layouts." + ) diff --git a/megatron/core/inference/kv_cache.py b/megatron/core/inference/kv_cache.py new file mode 100644 index 0000000000..4a71983754 --- /dev/null +++ b/megatron/core/inference/kv_cache.py @@ -0,0 +1,304 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""KV Cache implementations for different memory layouts.""" + +import warnings +from abc import ABC, abstractmethod +from enum import Enum +from typing import Optional + +import torch +from torch import Tensor + + +class KVCacheEfficiencyWarning(UserWarning): + """Custom warning for inefficient KV cache operations.""" + + pass + + +class KVCacheLayout(Enum): + """ + Enum representing the different KV cache memory layouts. + Note: Layer dimension is NOT included - it's handled outside the cache. 
+ + The names correspond to the data layout: + M = Merged + 2 = K/V dimension + N = Chunks, C = Chunk Size, H = Heads, D = Head Dimension + """ + + M_2NCHD = "KVCacheM2NCHD" + """Merged cache layout: [2, Chunks, ChunkSize, Heads, Dim]""" + + M_N2HCD = "KVCacheMN2HCD" + """Merged cache layout: [Chunks, 2, Heads, ChunkSize, Dim]""" + + +class KVCacheBase(ABC): + """ + Base class for KV cache implementations. + Each cache instance represents a single layer's cache. + """ + + def __init__( + self, + num_chunks: int, + chunk_size: int, + num_kv_heads: int, + head_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + self.num_chunks: int = num_chunks + self.chunk_size: int = chunk_size + self.num_kv_heads: int = num_kv_heads + self.head_dim: int = head_dim + self.device: Optional[torch.device] = device + self.dtype: Optional[torch.dtype] = dtype + + @abstractmethod + def get_content(self) -> Tensor: + """ + Returns the cache content tensor. + + Returns: + The cache tensor with K/V data. + """ + raise NotImplementedError + + @abstractmethod + def append( + self, + key: Tensor, + value: Tensor, + padded_active_token_count: int, + token_to_block_idx: Tensor, + token_to_local_position_within_kv_block: Tensor, + ) -> None: + """ + Appends key-value pairs to the cache. + + Args: + key: Key tensor to append + value: Value tensor to append + padded_active_token_count: Number of active tokens + token_to_block_idx: Mapping from token to block index + token_to_local_position_within_kv_block: Mapping from token to position within block + """ + raise NotImplementedError + + def supports_triton(self) -> bool: + """ + Returns True if this cache layout is compatible with Triton kernels. + All layouts (M_2NCHD, M_N2HCD, MLA) are Triton-compatible. + """ + return False + + +class MLACache(KVCacheBase): + """ + Cache for Multi-Latent Attention (MLA). + Stores compressed latent representation instead of full K/V. + Layout: [Chunks, ChunkSize, LatentDim] + """ + + def __init__( + self, + num_chunks: int, + chunk_size: int, + kv_reduced_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + # MLA doesn't use num_kv_heads or head_dim in the same way + super().__init__( + num_chunks=num_chunks, + chunk_size=chunk_size, + num_kv_heads=0, # Not used for MLA + head_dim=kv_reduced_dim, # Reuse head_dim for latent dim + device=device, + dtype=dtype, + ) + self.kv_reduced_dim = kv_reduced_dim + self.cache: Tensor = torch.full( + (num_chunks, chunk_size, kv_reduced_dim), + -1, + dtype=dtype, + device=device, + ) + + def get_content(self) -> Tensor: + """Returns the MLA latent cache tensor.""" + return self.cache + + def append( + self, + key: Tensor, + value: Tensor, + padded_active_token_count: int, + token_to_block_idx: Tensor, + token_to_local_position_within_kv_block: Tensor, + ) -> None: + """ + Append latent representation to MLA cache. + For MLA, 'key' contains the concatenated latent representation. + """ + block_idx = token_to_block_idx[:padded_active_token_count] + local_kv_seq_idx = token_to_local_position_within_kv_block[:padded_active_token_count] + + # For MLA, key contains the kv_concat latent representation + kv_concat = key.squeeze(1) + self.cache[block_idx, local_kv_seq_idx] = kv_concat[:padded_active_token_count] + + def supports_triton(self) -> bool: + """MLA is Triton-compatible.""" + return True + + +class KVCacheM2NCHD(KVCacheBase): + """ + Merged KV cache with shape [2, Chunks, ChunkSize, Heads, Dim]. 
+ Layout: 2, N, C, H, D + Triton-compatible for Flash backend. + """ + + def __init__( + self, + num_chunks: int, + chunk_size: int, + num_kv_heads: int, + head_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__(num_chunks, chunk_size, num_kv_heads, head_dim, device, dtype) + self.cache: Tensor = torch.full( + (2, num_chunks, chunk_size, num_kv_heads, head_dim), + -1, + dtype=dtype, + device=device, + ) + + def get_content(self) -> Tensor: + """Returns the merged cache tensor.""" + return self.cache + + def append( + self, + key: Tensor, + value: Tensor, + padded_active_token_count: int, + token_to_block_idx: Tensor, + token_to_local_position_within_kv_block: Tensor, + ) -> None: + """Append K/V to merged cache.""" + block_idx = token_to_block_idx[:padded_active_token_count] + local_kv_seq_idx = token_to_local_position_within_kv_block[:padded_active_token_count] + + assert key.size(1) == 1 and value.size(1) == 1, "Expected sequence length of 1" + key = key.squeeze(1) + value = value.squeeze(1) + + self.cache[0, block_idx, local_kv_seq_idx] = key[:padded_active_token_count] + self.cache[1, block_idx, local_kv_seq_idx] = value[:padded_active_token_count] + + def supports_triton(self) -> bool: + """M_2NCHD is Triton-compatible.""" + return True + + +class KVCacheMN2HCD(KVCacheBase): + """ + Merged KV cache with shape [Chunks, 2, Heads, ChunkSize, Dim]. + Layout: N, 2, H, C, D + Used by FlashInfer backends (fa2, fa3, trt). + """ + + def __init__( + self, + num_chunks: int, + chunk_size: int, + num_kv_heads: int, + head_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__(num_chunks, chunk_size, num_kv_heads, head_dim, device, dtype) + self.cache: Tensor = torch.full( + (num_chunks, 2, num_kv_heads, chunk_size, head_dim), + -1, + dtype=dtype, + device=device, + ) + + def get_content(self) -> Tensor: + """Returns the merged cache tensor.""" + return self.cache + + def append( + self, + key: Tensor, + value: Tensor, + padded_active_token_count: int, + token_to_block_idx: Tensor, + token_to_local_position_within_kv_block: Tensor, + ) -> None: + """Append K/V to merged cache.""" + block_idx = token_to_block_idx[:padded_active_token_count] + local_kv_seq_idx = token_to_local_position_within_kv_block[:padded_active_token_count] + + assert key.size(1) == 1 and value.size(1) == 1, "Expected sequence length of 1" + key = key.squeeze(1) + value = value.squeeze(1) + + self.cache[block_idx, 0, :, local_kv_seq_idx, :] = key[:padded_active_token_count] + self.cache[block_idx, 1, :, local_kv_seq_idx, :] = value[:padded_active_token_count] + + def supports_triton(self) -> bool: + """M_N2HCD is Triton-compatible.""" + return True + + +# --- Factory Function --- +def create_mhagqa_cache( + layout: KVCacheLayout, + num_chunks: int, + chunk_size: int, + num_kv_heads: int, + head_dim: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> KVCacheBase: + """ + Factory function to create a KV cache instance with a specified layout. + + Args: + layout: The desired memory layout from the KVCacheLayout enum. + num_chunks: Number of chunks (blocks) in the cache. + chunk_size: The number of tokens per chunk (block). + num_kv_heads: The number of key/value heads. + head_dim: The dimension of each head. + device: The torch device to create the cache on. + dtype: The torch dtype for the cache tensor. + + Returns: + An instance of a KVCacheBase subclass. 
+ """ + layout_to_class = { + KVCacheLayout.M_2NCHD: KVCacheM2NCHD, + KVCacheLayout.M_N2HCD: KVCacheMN2HCD, + } + + cache_class = layout_to_class.get(layout) + if cache_class is None: + raise ValueError(f"Unknown KV cache layout: {layout}") + + return cache_class( + num_chunks=num_chunks, + chunk_size=chunk_size, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + device=device, + dtype=dtype, + ) diff --git a/tests/unit_tests/inference/contexts/test_fused_kv_append.py b/tests/unit_tests/inference/contexts/test_fused_kv_append.py new file mode 100644 index 0000000000..d12f848efd --- /dev/null +++ b/tests/unit_tests/inference/contexts/test_fused_kv_append.py @@ -0,0 +1,308 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Tests for Triton-based KV cache append operations.""" + +import pytest +import torch + +from megatron.core.inference.contexts.fused_kv_append_kernel import triton_append_key_value_cache +from megatron.core.inference.kv_cache import KVCacheLayout, MLACache, create_mhagqa_cache + + +class TestFusedKVAppend: + """Test Triton-based KV cache append operations for all layouts.""" + + @pytest.fixture + def cache_params(self): + """Common cache parameters for testing.""" + return { + 'num_chunks': 8, + 'chunk_size': 64, + 'num_kv_heads': 8, + 'head_dim': 128, + 'dtype': torch.float16, + 'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), + } + + @pytest.fixture + def sample_data(self, cache_params): + """Generate sample K/V data for testing.""" + batch_size = 5 + device = cache_params['device'] + dtype = cache_params['dtype'] + + key = torch.randn( + batch_size, 1, cache_params['num_kv_heads'], cache_params['head_dim'], + dtype=dtype, device=device + ) + value = torch.randn( + batch_size, 1, cache_params['num_kv_heads'], cache_params['head_dim'], + dtype=dtype, device=device + ) + + token_to_block_idx = torch.tensor([0, 1, 2, 3, 4], device=device) + token_to_local_pos = torch.tensor([0, 5, 10, 15, 20], device=device) + + return key, value, token_to_block_idx, token_to_local_pos + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.parametrize( + "layout", + [ + KVCacheLayout.M_2NCHD, + KVCacheLayout.M_N2HCD, + ], + ) + def test_triton_append_mhagqa(self, layout, cache_params, sample_data): + """Test Triton append for all MHA/GQA layouts.""" + cache = create_mhagqa_cache(layout=layout, **cache_params) + key, value, block_idx, local_pos = sample_data + + # Verify cache supports Triton + assert cache.supports_triton(), f"{layout} should support Triton" + + # Append using Triton + triton_append_key_value_cache( + key=key, + value=value, + cache=cache, + padded_active_token_count=len(key), + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # Verify data was written correctly by checking specific positions + cache_content = cache.get_content() + + # For merged caches + for i, (b_idx, l_pos) in enumerate(zip(block_idx.tolist(), local_pos.tolist())): + if layout == KVCacheLayout.M_2NCHD: + # [2, N, C, H, D] + cached_k = cache_content[0, b_idx, l_pos, :, :] + cached_v = cache_content[1, b_idx, l_pos, :, :] + elif layout == KVCacheLayout.M_N2HCD: + # [N, 2, H, C, D] + cached_k = cache_content[b_idx, 0, :, l_pos, :] + cached_v = cache_content[b_idx, 1, :, l_pos, :] + + assert torch.allclose(cached_k, key[i, 0, :, :], rtol=1e-3, atol=1e-3) + assert torch.allclose(cached_v, value[i, 0, :, :], rtol=1e-3, atol=1e-3) + + @pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") + def test_triton_append_mla(self, cache_params): + """Test Triton append for MLA cache.""" + kv_reduced_dim = 256 + mla_cache = MLACache( + num_chunks=cache_params['num_chunks'], + chunk_size=cache_params['chunk_size'], + kv_reduced_dim=kv_reduced_dim, + dtype=cache_params['dtype'], + device=cache_params['device'], + ) + + # Verify MLA cache supports Triton + assert mla_cache.supports_triton(), "MLA cache should support Triton" + + batch_size = 5 + kv_concat = torch.randn( + batch_size, 1, kv_reduced_dim, + dtype=cache_params['dtype'], + device=cache_params['device'] + ) + block_idx = torch.tensor([0, 1, 2, 3, 4], device=cache_params['device']) + local_pos = torch.tensor([0, 5, 10, 15, 20], device=cache_params['device']) + + # Append using Triton (value=None for MLA) + triton_append_key_value_cache( + key=kv_concat, + value=None, + cache=mla_cache, + padded_active_token_count=len(kv_concat), + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # Verify data was written correctly + cache_content = mla_cache.get_content() + for i, (b_idx, l_pos) in enumerate(zip(block_idx.tolist(), local_pos.tolist())): + cached_kv = cache_content[b_idx, l_pos, :] + assert torch.allclose(cached_kv, kv_concat[i, 0, :], rtol=1e-3, atol=1e-3) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.parametrize( + "layout", + [ + KVCacheLayout.M_2NCHD, + KVCacheLayout.M_N2HCD, + ], + ) + def test_triton_vs_pytorch(self, layout, cache_params, sample_data): + """Compare Triton kernel output with PyTorch fallback.""" + # Create two identical caches + cache_triton = create_mhagqa_cache(layout=layout, **cache_params) + cache_pytorch = create_mhagqa_cache(layout=layout, **cache_params) + + key, value, block_idx, local_pos = sample_data + + # Triton path + triton_append_key_value_cache( + key=key, + value=value, + cache=cache_triton, + padded_active_token_count=len(key), + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # PyTorch path (using cache's native append method) + cache_pytorch.append( + key=key, + value=value, + padded_active_token_count=len(key), + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # Compare results + content_triton = cache_triton.get_content() + content_pytorch = cache_pytorch.get_content() + + # Merged cache + assert torch.allclose(content_triton, content_pytorch, rtol=1e-3, atol=1e-3) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_triton_vs_pytorch_mla(self, cache_params): + """Compare Triton kernel output with PyTorch fallback for MLA.""" + kv_reduced_dim = 256 + + # Create two identical MLA caches + cache_triton = MLACache( + num_chunks=cache_params['num_chunks'], + chunk_size=cache_params['chunk_size'], + kv_reduced_dim=kv_reduced_dim, + dtype=cache_params['dtype'], + device=cache_params['device'], + ) + cache_pytorch = MLACache( + num_chunks=cache_params['num_chunks'], + chunk_size=cache_params['chunk_size'], + kv_reduced_dim=kv_reduced_dim, + dtype=cache_params['dtype'], + device=cache_params['device'], + ) + + batch_size = 5 + kv_concat = torch.randn( + batch_size, 1, kv_reduced_dim, + dtype=cache_params['dtype'], + device=cache_params['device'] + ) + block_idx = torch.tensor([0, 1, 2, 3, 4], device=cache_params['device']) + local_pos = torch.tensor([0, 5, 10, 15, 20], device=cache_params['device']) + + # Triton path 
+ triton_append_key_value_cache( + key=kv_concat, + value=None, + cache=cache_triton, + padded_active_token_count=len(kv_concat), + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # PyTorch path + cache_pytorch.append( + key=kv_concat, + value=None, + padded_active_token_count=len(kv_concat), + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # Compare results + content_triton = cache_triton.get_content() + content_pytorch = cache_pytorch.get_content() + assert torch.allclose(content_triton, content_pytorch, rtol=1e-3, atol=1e-3) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_empty_batch(self, cache_params): + """Test handling of empty batches.""" + cache = create_mhagqa_cache(layout=KVCacheLayout.M_2NCHD, **cache_params) + + # Empty tensors + key = torch.empty(0, 1, cache_params['num_kv_heads'], cache_params['head_dim'], + dtype=cache_params['dtype'], device=cache_params['device']) + value = torch.empty(0, 1, cache_params['num_kv_heads'], cache_params['head_dim'], + dtype=cache_params['dtype'], device=cache_params['device']) + block_idx = torch.empty(0, dtype=torch.long, device=cache_params['device']) + local_pos = torch.empty(0, dtype=torch.long, device=cache_params['device']) + + # Should not raise an error + triton_append_key_value_cache( + key=key, + value=value, + cache=cache, + padded_active_token_count=0, + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_single_token(self, cache_params): + """Test appending a single token.""" + cache = create_mhagqa_cache(layout=KVCacheLayout.M_2NCHD, **cache_params) + + key = torch.randn(1, 1, cache_params['num_kv_heads'], cache_params['head_dim'], + dtype=cache_params['dtype'], device=cache_params['device']) + value = torch.randn(1, 1, cache_params['num_kv_heads'], cache_params['head_dim'], + dtype=cache_params['dtype'], device=cache_params['device']) + block_idx = torch.tensor([3], device=cache_params['device']) + local_pos = torch.tensor([42], device=cache_params['device']) + + triton_append_key_value_cache( + key=key, + value=value, + cache=cache, + padded_active_token_count=1, + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # Verify - M_2NCHD has layout [2, N, C, H, D] + cache_content = cache.get_content() + assert torch.allclose(cache_content[0, 3, 42, :, :], key[0, 0, :, :], rtol=1e-3, atol=1e-3) + assert torch.allclose(cache_content[1, 3, 42, :, :], value[0, 0, :, :], rtol=1e-3, atol=1e-3) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_full_block(self, cache_params): + """Test filling an entire block.""" + cache = create_mhagqa_cache(layout=KVCacheLayout.M_N2HCD, **cache_params) + + chunk_size = cache_params['chunk_size'] + key = torch.randn(chunk_size, 1, cache_params['num_kv_heads'], + cache_params['head_dim'], + dtype=cache_params['dtype'], device=cache_params['device']) + value = torch.randn(chunk_size, 1, cache_params['num_kv_heads'], + cache_params['head_dim'], + dtype=cache_params['dtype'], device=cache_params['device']) + + # Fill block 0 completely + block_idx = torch.zeros(chunk_size, dtype=torch.long, device=cache_params['device']) + local_pos = torch.arange(chunk_size, device=cache_params['device']) + + triton_append_key_value_cache( + key=key, + value=value, + cache=cache, 
+ padded_active_token_count=chunk_size, + token_to_block_idx=block_idx, + token_to_local_position_within_kv_block=local_pos, + ) + + # Verify all positions in block 0 are filled + cache_content = cache.get_content() # [N, 2, H, C, D] + for i in range(chunk_size): + cached_k = cache_content[0, 0, :, i, :] # block 0, K, all heads, pos i + cached_v = cache_content[0, 1, :, i, :] # block 0, V, all heads, pos i + assert torch.allclose(cached_k, key[i, 0, :, :], rtol=1e-3, atol=1e-3) + assert torch.allclose(cached_v, value[i, 0, :, :], rtol=1e-3, atol=1e-3) From fa8b31648377c695b04a4f287b6da9f95a6a1336 Mon Sep 17 00:00:00 2001 From: Kan Date: Sat, 27 Dec 2025 15:43:21 -0800 Subject: [PATCH 3/3] add main impl --- .../inference/gpt/gpt_dynamic_inference.py | 3 +- .../attention_context/flashinfer_metadata.py | 485 ++++++++++++++++++ .../inference/contexts/dynamic_context.py | 109 +++- .../contexts/fused_kv_append_kernel.py | 2 +- megatron/core/transformer/attention.py | 45 +- megatron/rl/inference/megatron.py | 3 +- .../contexts/test_dynamic_context.py | 3 +- .../inference/engines/test_dynamic_engine.py | 3 +- .../inference/test_wandb_logging.py | 6 +- .../test_simple_text_generation_controller.py | 3 +- .../models/test_gpt_model_batch_invariant.py | 6 +- tools/run_inference_performance_test.py | 3 +- 12 files changed, 635 insertions(+), 36 deletions(-) create mode 100644 megatron/core/inference/contexts/attention_context/flashinfer_metadata.py diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index e8a1aa7bf0..4e3e2d70e8 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -171,9 +171,10 @@ def get_inference_context( params_dtype=args.params_dtype, num_layers=args.num_layers // args.pipeline_model_parallel_size, kv_channels=args.kv_channels, - num_attention_heads=( + num_attention_kv_heads=( args.num_query_groups if args.group_query_attention else args.num_attention_heads ), + num_attention_qo_heads=args.num_attention_heads, max_sequence_length=max_sequence_length, num_cuda_graphs=( args.inference_dynamic_batching_num_cuda_graphs diff --git a/megatron/core/inference/contexts/attention_context/flashinfer_metadata.py b/megatron/core/inference/contexts/attention_context/flashinfer_metadata.py new file mode 100644 index 0000000000..b113d0cb53 --- /dev/null +++ b/megatron/core/inference/contexts/attention_context/flashinfer_metadata.py @@ -0,0 +1,485 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +FlashInfer attention metadata for dynamic inference. + +Provides a unified metadata class that uses simple dispatch logic: +- All decode requests → BatchDecodeWithPagedKVCacheWrapper +- Otherwise (any prefill) → BatchPrefillWithPagedKVCacheWrapper +""" + +import torch +from torch import Tensor + +from megatron.core.inference.batch_dimensions_utils import InferenceBatchDimensions +from megatron.core.transformer.enums import AttnBackend + +from .metadata_base import MetadataBase + +try: + import flashinfer + + HAVE_FLASHINFER = True +except ImportError: + HAVE_FLASHINFER = False + + +class FlashInferMetadata(MetadataBase): + """ + Unified FlashInfer attention metadata. 
+ + Simple dispatch logic: + - All decode requests → BatchDecodeWithPagedKVCacheWrapper + - Otherwise (any prefill) → BatchPrefillWithPagedKVCacheWrapper + """ + + # Map AttnBackend enum to FlashInfer backend strings + BACKEND_MAP = { + AttnBackend.flashinfer_fa2: "fa2", + AttnBackend.flashinfer_fa3: "fa3", + AttnBackend.flashinfer_trt: "trtllm-gen", + } + + def __init__( + self, + max_requests: int, + max_kv_block_count: int, + block_size_tokens: int, + backend: AttnBackend, + workspace_size: int = 512 * 1024 * 1024, # 512MB default + ): + """ + Initialize FlashInfer metadata. + + Args: + max_requests: Maximum number of concurrent requests + max_kv_block_count: Maximum number of KV blocks per request + block_size_tokens: Number of tokens per KV block (page size) + backend: FlashInfer backend type (fa2, fa3, trt) + workspace_size: Size of workspace buffer for FlashInfer wrappers + """ + super().__init__() + + if not HAVE_FLASHINFER: + raise ImportError("flashinfer is required for FlashInfer attention backend") + + self.device = torch.cuda.current_device() + self.max_requests = max_requests + self.max_kv_block_count = max_kv_block_count + self.block_size_tokens = block_size_tokens + self.backend = backend + self.flashinfer_backend = self.BACKEND_MAP.get(backend, "fa2") + + # Model parameters (set via set_model_params) + self._num_qo_heads = None + self._num_kv_heads = None + self._head_dim = None + self._params_dtype = None + + # Pre-allocate buffers for FlashInfer + # qo_indptr: cumulative query/output lengths [batch_size + 1] + self._qo_indptr_buf = torch.zeros( + max_requests + 1, dtype=torch.int32, device=self.device + ) + + # paged_kv_indptr: cumulative block counts [batch_size + 1] + self._paged_kv_indptr_buf = torch.zeros( + max_requests + 1, dtype=torch.int32, device=self.device + ) + + # paged_kv_indices: flattened block indices [total_blocks] + max_total_blocks = max_requests * max_kv_block_count + self._paged_kv_indices_buf = torch.zeros( + max_total_blocks, dtype=torch.int32, device=self.device + ) + + # paged_kv_last_page_len: tokens in last page per request [batch_size] + self._paged_kv_last_page_len_buf = torch.zeros( + max_requests, dtype=torch.int32, device=self.device + ) + + # Workspace buffer for FlashInfer + self.workspace_buffer = torch.empty( + workspace_size, dtype=torch.uint8, device=self.device + ) + + # kv_seq_lengths buffer for cu_kv_lengths compatibility + self._kv_seq_lengths_buf = torch.zeros( + max_requests, dtype=torch.int32, device=self.device + ) + + # cu_kv_seq_lengths buffer for cu_kv_lengths compatibility + self._cu_kv_seq_lengths_buf = torch.zeros( + max_requests + 1, dtype=torch.int32, device=self.device + ) + + # block_table buffer for key_value_cache compatibility + self._block_table_buf = torch.zeros( + (max_requests, max_kv_block_count), dtype=torch.int32, device=self.device + ) + + # Runtime state + self._is_all_decode = False + self._batch_size = 0 + self._total_blocks = 0 + self._max_seqlen_q = 0 + self._max_seqlen_k = 0 + + def set_model_params( + self, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + params_dtype: torch.dtype, + ): + """ + Set model parameters needed for FlashInfer planning. 
+ + Args: + num_qo_heads: Number of query/output heads + num_kv_heads: Number of key/value heads + head_dim: Dimension of each attention head + params_dtype: Data type for parameters (e.g., torch.float16) + """ + self._num_qo_heads = num_qo_heads + self._num_kv_heads = num_kv_heads + self._head_dim = head_dim + self._params_dtype = params_dtype + + @property + def prefill_wrapper(self): + """Return the prefill wrapper. Must be implemented by subclasses.""" + raise NotImplementedError + + @property + def decode_wrapper(self): + """Return the decode wrapper. Must be implemented by subclasses.""" + raise NotImplementedError + + def update( + self, + request_query_lengths: Tensor, + request_kv_length_offsets: Tensor, + request_to_kv_block_ids: Tensor, + batch_dimensions: InferenceBatchDimensions, + padded_batch_dimensions: InferenceBatchDimensions, + ): + """ + Update metadata from request states. + + Args: + request_query_lengths: Query token count per request (real_batch_size,) + request_kv_length_offsets: KV offset per request (real_batch_size,) + request_to_kv_block_ids: Block IDs per request (real_batch_size, max_kv_blocks) + batch_dimensions: Real batch dimensions + padded_batch_dimensions: Padded batch dimensions for CUDA graphs + """ + real_batch_size = batch_dimensions.req_count + padded_batch_size = padded_batch_dimensions.req_count + + self._batch_size = padded_batch_size + + # Determine if all requests are decode (query_length == 1 for all) + self._is_all_decode = ( + padded_batch_dimensions.prefill_req_count == 0 + and padded_batch_dimensions.decode_req_count > 0 + ) + + # Build qo_indptr: cumulative query lengths + self._qo_indptr_buf[0] = 0 + if real_batch_size > 0: + cumsum = torch.cumsum(request_query_lengths, dim=0) + self._qo_indptr_buf[1 : real_batch_size + 1] = cumsum + # Pad remaining entries + if padded_batch_size > real_batch_size: + last_val = cumsum[-1].item() + self._qo_indptr_buf[real_batch_size + 1 : padded_batch_size + 1] = last_val + + # Compute KV sequence lengths and block counts + kv_seq_lengths = request_kv_length_offsets + request_query_lengths + kv_block_counts = (kv_seq_lengths + self.block_size_tokens - 1) // self.block_size_tokens + + # Store kv_seq_lengths for cu_kv_lengths compatibility + if real_batch_size > 0: + self._kv_seq_lengths_buf[:real_batch_size] = kv_seq_lengths + # Pad remaining entries + if padded_batch_size > real_batch_size: + self._kv_seq_lengths_buf[real_batch_size:padded_batch_size] = 0 + + # Build cu_kv_seq_lengths: cumulative KV sequence lengths + self._cu_kv_seq_lengths_buf[0] = 0 + if real_batch_size > 0: + cumsum_kv = torch.cumsum(self._kv_seq_lengths_buf[:padded_batch_size], dim=0) + self._cu_kv_seq_lengths_buf[1 : padded_batch_size + 1] = cumsum_kv + + # Compute max_seqlen_q and max_seqlen_k + if padded_batch_dimensions.prefill_req_count == 0: + self._max_seqlen_q = 1 + else: + self._max_seqlen_q = max(2, request_query_lengths.max().item()) if real_batch_size > 0 else 1 + + self._max_seqlen_k = kv_seq_lengths.max().item() if real_batch_size > 0 else 0 + + # Build paged_kv_indptr: cumulative block counts + self._paged_kv_indptr_buf[0] = 0 + if real_batch_size > 0: + cumsum_blocks = torch.cumsum(kv_block_counts, dim=0) + self._paged_kv_indptr_buf[1 : real_batch_size + 1] = cumsum_blocks + self._total_blocks = cumsum_blocks[-1].item() + # Pad remaining entries + if padded_batch_size > real_batch_size: + self._paged_kv_indptr_buf[ + real_batch_size + 1 : padded_batch_size + 1 + ] = self._total_blocks + else: + self._total_blocks = 
0 + + # Flatten block table to paged_kv_indices + if real_batch_size > 0 and self._total_blocks > 0: + # Extract valid block IDs from block table + idx = 0 + for i in range(real_batch_size): + num_blocks = kv_block_counts[i].item() + self._paged_kv_indices_buf[idx : idx + num_blocks] = request_to_kv_block_ids[ + i, :num_blocks + ] + idx += num_blocks + + # Compute last page lengths + if real_batch_size > 0: + last_page_lens = kv_seq_lengths - (kv_block_counts - 1) * self.block_size_tokens + self._paged_kv_last_page_len_buf[:real_batch_size] = last_page_lens + # Pad remaining entries + if padded_batch_size > real_batch_size: + self._paged_kv_last_page_len_buf[real_batch_size:padded_batch_size] = 1 + + # Store block table for key_value_cache compatibility + if real_batch_size > 0: + self._block_table_buf[:real_batch_size, :] = request_to_kv_block_ids[:real_batch_size, :] + if padded_batch_size > real_batch_size: + self._block_table_buf[real_batch_size:padded_batch_size, :] = -1 + + # Plan the appropriate wrapper + self._plan_wrapper(padded_batch_size) + + # Store state data + self.state_data = { + "qo_indptr": self._qo_indptr_buf[: padded_batch_size + 1], + "paged_kv_indptr": self._paged_kv_indptr_buf[: padded_batch_size + 1], + "paged_kv_indices": self._paged_kv_indices_buf[: self._total_blocks], + "paged_kv_last_page_len": self._paged_kv_last_page_len_buf[:padded_batch_size], + "is_all_decode": self._is_all_decode, + "batch_size": padded_batch_size, + # Compatibility keys for cu_query_lengths() and cu_kv_lengths() + "cu_query_seq_lengths": self._qo_indptr_buf[: padded_batch_size + 1], # alias to qo_indptr + "cu_kv_seq_lengths": self._cu_kv_seq_lengths_buf[: padded_batch_size + 1], + "kv_seq_lengths": self._kv_seq_lengths_buf[:padded_batch_size], + "max_seqlen_q": self._max_seqlen_q, + "max_seqlen_k": self._max_seqlen_k, + "block_table": self._block_table_buf[:padded_batch_size, :], + } + + def _plan_wrapper(self, batch_size: int): + """Plan the FlashInfer wrapper with current metadata.""" + if batch_size == 0: + return + + if self._is_all_decode: + # Use decode wrapper + self.decode_wrapper.plan( + indptr=self._paged_kv_indptr_buf[: batch_size + 1], + indices=self._paged_kv_indices_buf[: self._total_blocks], + last_page_len=self._paged_kv_last_page_len_buf[:batch_size], + num_qo_heads=self._num_qo_heads, + num_kv_heads=self._num_kv_heads, + head_dim=self._head_dim, + page_size=self.block_size_tokens, + q_data_type=self._params_dtype, + kv_data_type=self._params_dtype, + block_tables=self._block_table_buf[:batch_size, :], + ) + else: + # Use prefill wrapper (uses head_dim_qk instead of head_dim) + self.prefill_wrapper.plan( + qo_indptr=self._qo_indptr_buf[: batch_size + 1], + paged_kv_indptr=self._paged_kv_indptr_buf[: batch_size + 1], + paged_kv_indices=self._paged_kv_indices_buf[: self._total_blocks], + paged_kv_last_page_len=self._paged_kv_last_page_len_buf[:batch_size], + num_qo_heads=self._num_qo_heads, + num_kv_heads=self._num_kv_heads, + head_dim_qk=self._head_dim, + page_size=self.block_size_tokens, + q_data_type=self._params_dtype, + kv_data_type=self._params_dtype, + causal=True, + block_tables=self._block_table_buf[:batch_size, :], + ) + + def attention( + self, + q: Tensor, + kv_cache: Tensor, + softmax_scale: float = None, + layer_idx: int = 0, + ) -> Tensor: + """ + Run FlashInfer attention. 
+ + Args: + q: Query tensor of shape (batch, 1, num_heads, head_dim) or (tokens, num_heads, head_dim) + kv_cache: KV cache tensor with layout M_N2HCD [blocks, 2, num_kv_heads, block_size, head_dim] + softmax_scale: Optional softmax scale + layer_idx: Layer index (unused for now) + + Returns: + Output tensor of shape (batch, 1, num_heads, head_dim) + """ + # Squeeze sequence dimension if present: (batch, 1, heads, dim) -> (batch, heads, dim) + if q.dim() == 4: + q_input = q.squeeze(1) + else: + q_input = q + + # Run appropriate wrapper + if self._is_all_decode: + output = self.decode_wrapper.run(q_input, kv_cache) + else: + output = self.prefill_wrapper.run(q_input, kv_cache) + + # Restore sequence dimension: (batch, heads, dim) -> (batch, 1, heads, dim) + return output.unsqueeze(1) + + def reset(self): + """Reset metadata for next batch.""" + self._qo_indptr_buf.fill_(0) + self._paged_kv_indptr_buf.fill_(0) + self._paged_kv_indices_buf.fill_(0) + self._paged_kv_last_page_len_buf.fill_(0) + self._kv_seq_lengths_buf.fill_(0) + self._cu_kv_seq_lengths_buf.fill_(0) + self._block_table_buf.fill_(0) + self._is_all_decode = False + self._batch_size = 0 + self._total_blocks = 0 + self._max_seqlen_q = 0 + self._max_seqlen_k = 0 + self.state_data = {} + + +class GraphFlashInferMetadata(FlashInferMetadata): + """ + FlashInfer metadata for CUDA graph mode. + + Pre-binds buffers to wrappers and caches wrappers per batch size. + """ + + def __init__( + self, + max_requests: int, + max_kv_block_count: int, + block_size_tokens: int, + backend: AttnBackend, + workspace_size: int = 512 * 1024 * 1024, + ): + super().__init__( + max_requests=max_requests, + max_kv_block_count=max_kv_block_count, + block_size_tokens=block_size_tokens, + backend=backend, + workspace_size=workspace_size, + ) + + # Cache wrappers per batch size + self._prefill_wrappers_by_bs = {} + self._decode_wrappers_by_bs = {} + + def _get_prefill_wrapper(self, batch_size: int): + """Get or create prefill wrapper for given batch size.""" + if batch_size not in self._prefill_wrappers_by_bs: + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "HND", + use_cuda_graph=True, + qo_indptr_buf=self._qo_indptr_buf[: batch_size + 1], + paged_kv_indptr_buf=self._paged_kv_indptr_buf[: batch_size + 1], + paged_kv_indices_buf=self._paged_kv_indices_buf, + paged_kv_last_page_len_buf=self._paged_kv_last_page_len_buf[:batch_size], + backend=self.flashinfer_backend, + ) + self._prefill_wrappers_by_bs[batch_size] = wrapper + return self._prefill_wrappers_by_bs[batch_size] + + def _get_decode_wrapper(self, batch_size: int): + """Get or create decode wrapper for given batch size.""" + if batch_size not in self._decode_wrappers_by_bs: + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "HND", + use_tensor_cores=True, + use_cuda_graph=True, + paged_kv_indptr_buffer=self._paged_kv_indptr_buf[: batch_size + 1], + paged_kv_indices_buffer=self._paged_kv_indices_buf, + paged_kv_last_page_len_buffer=self._paged_kv_last_page_len_buf[:batch_size], + backend=self.flashinfer_backend, + ) + self._decode_wrappers_by_bs[batch_size] = wrapper + return self._decode_wrappers_by_bs[batch_size] + + @property + def prefill_wrapper(self): + return self._get_prefill_wrapper(self._batch_size) + + @property + def decode_wrapper(self): + return self._get_decode_wrapper(self._batch_size) + + +class NonGraphFlashInferMetadata(FlashInferMetadata): + """ + FlashInfer metadata for non-CUDA graph mode. 
+ + Creates wrappers lazily without buffer binding. + """ + + def __init__( + self, + max_requests: int, + max_kv_block_count: int, + block_size_tokens: int, + backend: AttnBackend, + workspace_size: int = 512 * 1024 * 1024, + ): + super().__init__( + max_requests=max_requests, + max_kv_block_count=max_kv_block_count, + block_size_tokens=block_size_tokens, + backend=backend, + workspace_size=workspace_size, + ) + + # Lazy wrapper initialization + self._prefill_wrapper = None + self._decode_wrapper = None + + @property + def prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "HND", + use_cuda_graph=False, + backend=self.flashinfer_backend, + ) + return self._prefill_wrapper + + @property + def decode_wrapper(self): + if self._decode_wrapper is None: + self._decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "HND", + use_tensor_cores=True, + use_cuda_graph=False, + backend=self.flashinfer_backend, + ) + return self._decode_wrapper diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index f66d6f2d4c..782c5f20ec 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -37,6 +37,15 @@ from .attention_context.mamba_metadata import MambaInferenceStateConfig, MambaMetadata from .attention_context.mha_metadata import GraphedMHAMetadata, NonGraphedMHAMetadata from .base_context import BaseInferenceContext + +try: + from .attention_context.flashinfer_metadata import ( + GraphFlashInferMetadata, + NonGraphFlashInferMetadata, + ) + HAVE_FLASHINFER_METADATA = True +except ImportError: + HAVE_FLASHINFER_METADATA = False from .dynamic_block_allocator import BlockAllocator from ..kv_cache import KVCacheBase, KVCacheLayout, MLACache, create_mhagqa_cache @@ -209,7 +218,8 @@ class DynamicInferenceContext(BaseInferenceContext): params_dtype (torch.dtype): Dtype used for KV cache. num_layers (int): Number of layers on this pipeline parallel rank. kv_channels (int): Hidden dimension per attention head. - num_attention_heads (int): Number of attention heads. + num_attention_kv_heads (int): Number of key/value attention heads. + num_attention_qo_heads (int): Number of query/output attention heads. max_sequence_length (int): Max possible sequence length (prompt + output) that will occur. buffer_size_gb (float): Buffer size reserved on the GPU for the KV cache. @@ -249,7 +259,7 @@ class DynamicInferenceContext(BaseInferenceContext): label, the target dtype, and whether to store the data on GPU. """ - DEFAULT_MAX_TOKENS = 16384 + DEFAULT_MAX_TOKENS = 32768 TOKEN_ROUNDER = 64 REQUEST_ROUNDER = 4 @@ -259,7 +269,8 @@ def __init__( params_dtype: torch.dtype, num_layers: int, kv_channels: int, - num_attention_heads: int, + num_attention_kv_heads: int, + num_attention_qo_heads: int, max_sequence_length: int, buffer_size_gb: float, max_requests: int = None, @@ -302,13 +313,14 @@ def __init__( self.metrics_writer = metrics_writer # Per partition num heads and hidden size. 
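A minimal worked example of the paged-KV bookkeeping that FlashInferMetadata.update() above produces. This is a sketch only: the two-request batch, the lengths, and the block size are assumed values, not taken from the patch.

import torch

# Assumed toy batch: two requests with block_size_tokens = 4.
block_size_tokens = 4
request_query_lengths = torch.tensor([3, 1])      # request 0 is prefilling 3 tokens, request 1 is decoding
request_kv_length_offsets = torch.tensor([0, 6])  # tokens already present in the KV cache

# qo_indptr: exclusive prefix sum of query lengths.
qo_indptr = torch.zeros(3, dtype=torch.int64)
qo_indptr[1:] = torch.cumsum(request_query_lengths, dim=0)

# KV length after this step, and how many KV blocks each request spans.
kv_seq_lengths = request_kv_length_offsets + request_query_lengths               # [3, 7]
kv_block_counts = (kv_seq_lengths + block_size_tokens - 1) // block_size_tokens  # [1, 2]

# paged_kv_indptr: prefix sum of block counts; paged_kv_indices would then hold
# the three physical block ids read out of the request block table.
paged_kv_indptr = torch.zeros(3, dtype=torch.int64)
paged_kv_indptr[1:] = torch.cumsum(kv_block_counts, dim=0)

# last_page_len: tokens used in each request's final block.
paged_kv_last_page_len = kv_seq_lengths - (kv_block_counts - 1) * block_size_tokens

print(qo_indptr.tolist())               # [0, 3, 4]
print(paged_kv_indptr.tolist())         # [0, 1, 3]
print(paged_kv_last_page_len.tolist())  # [3, 3]

These are the tensors handed to the prefill/decode plan() calls in _plan_wrapper().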
-        projection_size = kv_channels * num_attention_heads
+        projection_size = kv_channels * num_attention_qo_heads
         if tensor_model_parallel_size is None:
             tp_size = parallel_state.get_tensor_model_parallel_world_size()
         else:
             tp_size = tensor_model_parallel_size
-        self.hidden_size_per_attention_head = core_divide(projection_size, num_attention_heads)
-        self.num_attention_heads_per_partition = core_divide(num_attention_heads, tp_size)
+        self.hidden_size_per_attention_head = core_divide(projection_size, num_attention_qo_heads)
+        self.num_attention_kv_heads_per_partition = core_divide(num_attention_kv_heads, tp_size)
+        self.num_attention_qo_heads_per_partition = core_divide(num_attention_qo_heads, tp_size)
 
         # Mamba states.
         self.is_hybrid_model = mamba_inference_state_config is not None
@@ -366,7 +378,7 @@ def __init__(
             * 2  # key, value
             * self.num_attention_layers
             * self.block_size_tokens
-            * self.num_attention_heads_per_partition
+            * self.num_attention_kv_heads_per_partition
            * self.hidden_size_per_attention_head
         )
         assert self.block_size_bytes > 0
@@ -479,7 +491,7 @@ def __init__(
             CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list(
                 tp_size=tp_size,
                 num_cuda_graphs=num_cuda_graphs,
-                cuda_graph_max_tokens=self.max_active_requests,
+                cuda_graph_max_tokens=1024,  # was: self.max_active_requests
                 cuda_graph_mixed_prefill_count=cuda_graph_mixed_prefill_count,
                 max_requests=self.max_active_requests,
                 max_tokens=self.max_tokens,
@@ -618,7 +630,7 @@ def allocate_memory_buffer():
                 layout=self._cache_layout,
                 num_chunks=self.block_allocator.total_count,
                 chunk_size=self.block_size_tokens,
-                num_kv_heads=self.num_attention_heads_per_partition,
+                num_kv_heads=self.num_attention_kv_heads_per_partition,
                 head_dim=self.hidden_size_per_attention_head,
                 dtype=self.params_dtype,
                 device=torch.cuda.current_device(),
@@ -656,6 +668,66 @@ def allocate_mamba_states():
         allocate_memory_buffer()
         allocate_mamba_states()
 
+        # Initialize attention metadata based on backend.
+        self.graph_attn_metadata = {}
+        self.non_graph_attn_metadata = {}
+        self.active_attn_metadata = None
+
+        if self.attention_backend in [
+            AttnBackend.flashinfer_fa2,
+            AttnBackend.flashinfer_fa3,
+            AttnBackend.flashinfer_trt,
+        ]:
+            # FlashInfer backends use FlashInferMetadata.
+            if not HAVE_FLASHINFER_METADATA:
+                raise ImportError(
+                    "FlashInfer metadata is required for FlashInfer attention backends. "
+                    "Please ensure flashinfer is installed."
+                )
+
+            self.graph_attn_metadata["mha_metadata"] = GraphFlashInferMetadata(
+                max_requests=self.max_active_requests,
+                max_kv_block_count=self.max_kv_block_count,
+                block_size_tokens=self.block_size_tokens,
+                backend=self.attention_backend,
+            )
+
+            self.non_graph_attn_metadata["mha_metadata"] = NonGraphFlashInferMetadata(
+                max_requests=self.max_active_requests,
+                max_kv_block_count=self.max_kv_block_count,
+                block_size_tokens=self.block_size_tokens,
+                backend=self.attention_backend,
+            )
+
+            # Set model parameters for FlashInfer planning.
+            for metadata in [
+                self.graph_attn_metadata["mha_metadata"],
+                self.non_graph_attn_metadata["mha_metadata"],
+            ]:
+                metadata.set_model_params(
+                    num_qo_heads=self.num_attention_qo_heads_per_partition,
+                    num_kv_heads=self.num_attention_kv_heads_per_partition,
+                    head_dim=self.hidden_size_per_attention_head,
+                    params_dtype=self.params_dtype,
+                )
+        else:
+            # Default: Flash Attention backends use MHAMetadata.
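A quick sanity check of the per-partition head split and KV block sizing in the hunk above. This is a sketch with assumed GQA sizes; the tensor-parallel degree and the leading dtype-size factor of block_size_bytes are assumptions, not values from the patch.

# Assumed sizes: 32 query heads, 8 KV heads (GQA), kv_channels=128, TP=4, bf16 cache.
num_attention_qo_heads, num_attention_kv_heads = 32, 8
kv_channels, tp_size = 128, 4
num_attention_layers, block_size_tokens, dtype_bytes = 32, 16, 2

hidden_size_per_attention_head = (kv_channels * num_attention_qo_heads) // num_attention_qo_heads  # 128
num_attention_qo_heads_per_partition = num_attention_qo_heads // tp_size  # 8
num_attention_kv_heads_per_partition = num_attention_kv_heads // tp_size  # 2

# block_size_bytes scales with the KV heads per partition (see the hunk above).
block_size_bytes = (
    dtype_bytes
    * 2  # key, value
    * num_attention_layers
    * block_size_tokens
    * num_attention_kv_heads_per_partition
    * hidden_size_per_attention_head
)
print(block_size_bytes)  # 524288 bytes per KV block on this rank

The call sites updated later in this patch pass num_query_groups as the KV head count and num_attention_heads as the QO head count.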
+ self.graph_attn_metadata["mha_metadata"] = GraphedMHAMetadata( + block_count_total=self.block_allocator.total_count, + max_kv_block_count=self.max_kv_block_count, + max_requests=self.max_active_requests, + block_size_tokens=self.block_size_tokens, + max_seqlen=self.max_sequence_length, + ) + + self.non_graph_attn_metadata["mha_metadata"] = NonGraphedMHAMetadata( + block_count_total=self.block_allocator.total_count, + max_kv_block_count=self.max_kv_block_count, + max_requests=self.max_active_requests, + block_size_tokens=self.block_size_tokens, + max_seqlen=self.max_sequence_length, + ) + # Reset attention and Mamba state. self.reset_attention_state() self.reset_mamba_state() @@ -730,7 +802,8 @@ def from_config( params_dtype=inference_config.params_dtype, num_layers=model_config.num_layers // model_config.pipeline_model_parallel_size, kv_channels=model_config.kv_channels, - num_attention_heads=model_config.num_query_groups, + num_attention_kv_heads=model_config.num_query_groups, + num_attention_qo_heads=model_config.num_attention_heads, max_sequence_length=inference_config.inference_max_seq_length, buffer_size_gb=buffer_size_gb, materialize_only_last_token_logits=False, @@ -848,6 +921,22 @@ def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) token_to_local_position_within_kv_block=self.token_to_local_position_within_kv_block, ) + def get_kv_cache_content(self, layer_number: int) -> Tensor: + """Get raw KV cache content tensor for a layer. + + This method handles the layer_map for hybrid models and returns + the raw cache tensor directly (for use with FlashInfer). + + Args: + layer_number (int): Layer number (1-based). + + Return: + (Tensor) The raw KV cache content tensor. + """ + attention_layer_number = self.layer_map[layer_number - 1] + cache = self.memory_buffer[attention_layer_number] + return cache.get_content() + def key_value_cache(self, layer_number: int) -> Tuple[Tensor, Optional[Tensor], Tensor]: """Read from KV cache. diff --git a/megatron/core/inference/contexts/fused_kv_append_kernel.py b/megatron/core/inference/contexts/fused_kv_append_kernel.py index 42b3a59e44..f3ee253763 100644 --- a/megatron/core/inference/contexts/fused_kv_append_kernel.py +++ b/megatron/core/inference/contexts/fused_kv_append_kernel.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -from typing import Optional +from typing import Optional, Tuple import triton import triton.language as tl diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 2f26edb8c3..493cac8c57 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -41,7 +41,7 @@ from ..models.common.embeddings.yarn_rotary_pos_embedding import ( _yarn_get_concentration_factor_from_config, ) -from .enums import AttnMaskType +from .enums import AttnBackend, AttnMaskType from .transformer_config import TransformerConfig try: @@ -965,21 +965,36 @@ def forward( else: # Dynamic batching attention kernel. 
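Condensed usage sketch of the pieces above; the attention.py hunk just below does the same thing for real. Here ctx stands for a DynamicInferenceContext whose metadata has already been updated for the current step, and q for the current query tensor; both are placeholders, not names from the patch.

def flashinfer_attention_step(ctx, q, layer_number):
    # Raw per-layer cache tensor (layer_map is applied inside get_kv_cache_content),
    # laid out as [blocks, 2, num_kv_heads, block_size, head_dim] for FlashInfer.
    kv_cache = ctx.get_kv_cache_content(layer_number)
    # The planned prefill/decode wrapper lives on the active FlashInfer metadata object.
    metadata = ctx.active_attn_metadata["mha_metadata"]
    return metadata.attention(q, kv_cache, layer_idx=layer_number - 1)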
q, k, v = (query, key, value) - cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() - cu_kv_lengths, kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() - core_attn_out = self.flash_decode_and_prefill( - q, - k, - v, - max_seqlen_q, - max_seqlen_k, - cu_query_lengths, - cu_kv_lengths, - kv_lengths, - block_table, - ) - core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + if inference_context.attention_backend in [ + AttnBackend.flashinfer_fa2, + AttnBackend.flashinfer_fa3, + AttnBackend.flashinfer_trt, + ]: + # FlashInfer attention using KV cache directly + pp_layer_offset = self._get_pp_layer_offset_for_inference() + layer_number_adjusted = self.layer_number - pp_layer_offset + kv_cache = inference_context.get_kv_cache_content(layer_number_adjusted) + core_attn_out = inference_context.active_attn_metadata[ + "mha_metadata" + ].attention(q, kv_cache, layer_idx=layer_number_adjusted - 1) + core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + else: + cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths() + cu_kv_lengths, kv_lengths, max_seqlen_k = inference_context.cu_kv_lengths() + + core_attn_out = self.flash_decode_and_prefill( + q, + k, + v, + max_seqlen_q, + max_seqlen_k, + cu_query_lengths, + cu_kv_lengths, + kv_lengths, + block_table, + ) + core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': # reshape to same output shape as unpacked case diff --git a/megatron/rl/inference/megatron.py b/megatron/rl/inference/megatron.py index e67900e20a..7a30786aff 100644 --- a/megatron/rl/inference/megatron.py +++ b/megatron/rl/inference/megatron.py @@ -134,9 +134,10 @@ def get_dynamic_inference_engine( params_dtype=args.params_dtype, num_layers=args.num_layers // args.pipeline_model_parallel_size, kv_channels=args.kv_channels, - num_attention_heads=( + num_attention_kv_heads=( args.num_query_groups if args.group_query_attention else args.num_attention_heads ), + num_attention_qo_heads=args.num_attention_heads, max_sequence_length=args.inference_max_seq_length, num_cuda_graphs=( args.inference_dynamic_batching_num_cuda_graphs if enable_cuda_graph else None diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 2da334191a..f386d81499 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -70,7 +70,8 @@ def _get_dynamic_context( params_dtype=params_dtype, num_layers=num_layers // self.pp_size, kv_channels=kv_channels, - num_attention_heads=num_attention_heads, + num_attention_kv_heads=num_attention_heads, + num_attention_qo_heads=num_attention_heads, max_sequence_length=max_sequence_length, num_cuda_graphs=None, use_cuda_graphs_for_non_decode_steps=not is_hybrid_model, diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index ef6252094a..7d8f930fcc 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -218,7 +218,8 @@ def _build_inference_context( num_layers=transformer_config.num_layers // transformer_config.pipeline_model_parallel_size, kv_channels=transformer_config.kv_channels, - num_attention_heads=transformer_config.num_query_groups, + num_attention_kv_heads=transformer_config.num_query_groups, + 
num_attention_qo_heads=transformer_config.num_attention_heads, max_sequence_length=test_config.max_sequence_length, num_cuda_graphs=test_config.num_cuda_graphs, use_cuda_graphs_for_non_decode_steps=not test_config.model_provider == "mamba", diff --git a/tests/unit_tests/inference/test_wandb_logging.py b/tests/unit_tests/inference/test_wandb_logging.py index 1d5d054b80..0f4262a973 100644 --- a/tests/unit_tests/inference/test_wandb_logging.py +++ b/tests/unit_tests/inference/test_wandb_logging.py @@ -57,7 +57,8 @@ def _get_dynamic_context( params_dtype=params_dtype, num_layers=num_layers, kv_channels=kv_channels, - num_attention_heads=num_attention_heads, + num_attention_kv_heads=num_attention_heads, + num_attention_qo_heads=num_attention_heads, max_sequence_length=max_sequence_length, num_cuda_graphs=None, buffer_size_gb=buffer_size_gb, @@ -232,7 +233,8 @@ def test_paused_requests_in_stats(self): params_dtype=torch.float32, num_layers=2, kv_channels=64, - num_attention_heads=8, + num_attention_kv_heads=8, + num_attention_qo_heads=8, max_sequence_length=128, num_cuda_graphs=None, buffer_size_gb=0.01, # Small buffer to force pausing diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py index ebf558d3fa..1cd2c75f83 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py @@ -117,7 +117,8 @@ def setup_model( params_dtype=dtype, num_layers=transformer_config.num_layers // pipeline_model_parallel_size, kv_channels=transformer_config.kv_channels, - num_attention_heads=transformer_config.num_attention_heads, + num_attention_kv_heads=transformer_config.num_query_groups, + num_attention_qo_heads=transformer_config.num_attention_heads, max_sequence_length=2048, buffer_size_gb=0.2, materialize_only_last_token_logits=False, diff --git a/tests/unit_tests/models/test_gpt_model_batch_invariant.py b/tests/unit_tests/models/test_gpt_model_batch_invariant.py index ead9125e5e..5a9572e62f 100644 --- a/tests/unit_tests/models/test_gpt_model_batch_invariant.py +++ b/tests/unit_tests/models/test_gpt_model_batch_invariant.py @@ -187,7 +187,8 @@ def test_dynamic_engine_matches_batched_forward_rl(self): params_dtype=torch.bfloat16, num_layers=base_model.config.num_layers, kv_channels=base_model.config.kv_channels, - num_attention_heads=base_model.config.num_attention_heads, + num_attention_kv_heads=base_model.config.num_query_groups, + num_attention_qo_heads=base_model.config.num_attention_heads, max_sequence_length=seq_len, buffer_size_gb=0.125, block_size_tokens=16, @@ -276,7 +277,8 @@ def _run_engine_with_order(order): params_dtype=torch.bfloat16, num_layers=base_model.config.num_layers, kv_channels=base_model.config.kv_channels, - num_attention_heads=base_model.config.num_attention_heads, + num_attention_kv_heads=base_model.config.num_query_groups, + num_attention_qo_heads=base_model.config.num_attention_heads, max_sequence_length=seq_len, buffer_size_gb=0.125, block_size_tokens=16, diff --git a/tools/run_inference_performance_test.py b/tools/run_inference_performance_test.py index dda2b8284b..c42159bb1d 100644 --- a/tools/run_inference_performance_test.py +++ b/tools/run_inference_performance_test.py @@ -104,9 +104,10 @@ def get_inference_engine(args: argparse.Namespace, model: MegatronModule) -> Abs 
params_dtype=args.params_dtype, num_layers=args.num_layers, kv_channels=args.kv_channels, - num_attention_heads=( + num_attention_kv_heads=( args.num_query_groups if args.group_query_attention else args.num_attention_heads ), + num_attention_qo_heads=args.num_attention_heads, max_sequence_length=args.inference_max_seq_length, num_cuda_graphs=( args.inference_dynamic_batching_num_cuda_graphs