diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 85bd327cfc6..5cee37b19cf 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -354,7 +354,7 @@ EOF fi ;; qwen3_5_moe) - RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 32" + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'What is the capital of France?' --max_new_tokens 128 --temperature 0" ;; voxtral_realtime) RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0" diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 5acb6efe652..68ded356b99 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -145,8 +145,8 @@ jobs: # Run CUDA backend Python tests python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts=" - # Run quantize roundtrip tests (Qwen 3.5 MoE save/load prequantized) - python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py -v -o "addopts=" + # Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache) + python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts=" export-model-cuda-artifact: name: export-model-cuda-artifact diff --git a/backends/aoti/common_shims_slim.cpp b/backends/aoti/common_shims_slim.cpp index e4378433e8a..739b3ee68c0 100644 --- a/backends/aoti/common_shims_slim.cpp +++ b/backends/aoti/common_shims_slim.cpp @@ -134,6 +134,10 @@ int32_t aoti_torch_dtype_int8() { return 1; // ScalarType::Char } +int32_t aoti_torch_dtype_uint8() { + return 0; // ScalarType::Byte +} + int32_t aoti_torch_dtype_bool() { return 11; // ScalarType::Bool } diff --git a/backends/aoti/common_shims_slim.h b/backends/aoti/common_shims_slim.h 
index a98dd765978..75ede847d5a 100644 --- a/backends/aoti/common_shims_slim.h +++ b/backends/aoti/common_shims_slim.h @@ -76,6 +76,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_uint8(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool(); // ============================================================ diff --git a/backends/aoti/slim/c10/core/ScalarType.h b/backends/aoti/slim/c10/core/ScalarType.h index 3acf2bebc48..c1499a83f39 100644 --- a/backends/aoti/slim/c10/core/ScalarType.h +++ b/backends/aoti/slim/c10/core/ScalarType.h @@ -23,7 +23,7 @@ using BFloat16 = ::executorch::runtime::etensor::BFloat16; /// Enum representing the scalar type (dtype) of tensor elements. /// Note: Enum values must match PyTorch's c10::ScalarType for compatibility. enum class ScalarType : int8_t { - // Byte = 0, // uint8_t - not currently needed + Byte = 0, // uint8_t Char = 1, // int8_t Short = 2, // int16_t Int = 3, // int32_t @@ -43,6 +43,7 @@ enum class ScalarType : int8_t { }; // Type alias constants for convenience +constexpr ScalarType kByte = ScalarType::Byte; constexpr ScalarType kChar = ScalarType::Char; constexpr ScalarType kShort = ScalarType::Short; constexpr ScalarType kInt = ScalarType::Int; @@ -56,6 +57,8 @@ constexpr ScalarType kBFloat16 = ScalarType::BFloat16; /// @return The size in bytes of a single element. inline size_t elementSize(ScalarType t) { switch (t) { + case ScalarType::Byte: + return sizeof(uint8_t); case ScalarType::Char: return sizeof(int8_t); case ScalarType::Short: @@ -80,6 +83,8 @@ inline size_t elementSize(ScalarType t) { /// @return The name of the scalar type. 
inline const char* toString(ScalarType t) { switch (t) { + case ScalarType::Byte: + return "Byte"; case ScalarType::Char: return "Char"; case ScalarType::Short: @@ -114,6 +119,7 @@ inline bool isFloatingType(ScalarType t) { /// @return true if the scalar type is integral, false otherwise. inline bool isIntegralType(ScalarType t, bool includeBool) { switch (t) { + case ScalarType::Byte: case ScalarType::Char: case ScalarType::Short: case ScalarType::Int: @@ -138,6 +144,7 @@ inline bool isBoolType(ScalarType t) { /// @return true if the scalar type is valid, false otherwise. inline bool isValidScalarType(ScalarType t) { switch (t) { + case ScalarType::Byte: case ScalarType::Char: case ScalarType::Short: case ScalarType::Int: diff --git a/backends/cuda/tests/test_tq4_sdpa.py b/backends/cuda/tests/test_tq4_sdpa.py new file mode 100644 index 00000000000..9cf1e9e2d57 --- /dev/null +++ b/backends/cuda/tests/test_tq4_sdpa.py @@ -0,0 +1,802 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for the TQ4 fused SDPA kernel (tq4_sdpa). + +Verifies that attention over nibble-packed TQ4-compressed K/V cache +matches a reference path: decompress K/V → standard SDPA in float32. + +Test structure follows test_triton_sdpa.py. 
+""" + +import os +import subprocess +import tempfile +import unittest + +import numpy as np +import torch +import torch.nn.functional as F + +from executorch.backends.cuda.cuda_backend import CudaBackend +from executorch.backends.cuda.cuda_partitioner import CudaPartitioner +from executorch.backends.cuda.triton.kernels.tq4_sdpa import tq4_sdpa +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass +from executorch.extension.llm.modules.turboquant import TurboQuantKVCache +from executorch.extension.llm.modules.turboquant.codebook import ( + generate_rotation_matrix, + solve_lloyd_max, +) +from torch.export import export + + +def _skip_if_no_cuda(): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available") + if not torch.cuda.is_bf16_supported(): + raise unittest.SkipTest("BF16 not supported on this GPU") + + +def _make_codebook_and_rotation(head_dim, bits=4, seed=42): + """Precompute TQ4 constants.""" + centroids, boundaries = solve_lloyd_max(head_dim, bits) + rotation = generate_rotation_matrix(head_dim, seed=seed) + return centroids, boundaries, rotation + + +def _compress(x, boundaries, rotation): + """Compress (B, H, S, D) tensor to nibble-packed uint8 + fp32 norms.""" + B, H, S, D = x.shape + flat = x.reshape(-1, D).float() + + norms = torch.linalg.vector_norm(flat, dim=-1, keepdim=True) + normalized = flat / (norms + 1e-10) + rotated = normalized @ rotation.float().T + indices = torch.bucketize(rotated, boundaries.float()) + + idx_u8 = indices.to(torch.uint8) + packed = (idx_u8[:, 0::2] << 4) | idx_u8[:, 1::2] + + return packed.reshape(B, H, S, D // 2), norms.reshape(B, H, S, 1) + + +def _decompress(packed, norms, centroids, rotation): + """Decompress nibble-packed uint8 + fp32 norms to float tensor.""" + B, H, S, half_D = packed.shape + D = half_D * 2 + flat = packed.reshape(-1, half_D) + flat_norms = norms.reshape(-1, 1) 
+ + high = (flat >> 4).long() + low = (flat & 0x0F).long() + indices = torch.stack([high, low], dim=-1).reshape(-1, D) + + reconstructed = centroids.float()[indices] + unrotated = reconstructed @ rotation.float() + scaled = unrotated * flat_norms + + return scaled.reshape(B, H, S, D) + + +def _reference_tq4_sdpa( + q, k, v, centroids, boundaries, rotation, attn_mask=None, is_causal=False +): + """Reference: compress K/V, decompress, run standard SDPA in float32.""" + k_packed, k_norms = _compress(k, boundaries, rotation) + v_packed, v_norms = _compress(v, boundaries, rotation) + + k_dec = _decompress(k_packed, k_norms, centroids, rotation) + v_dec = _decompress(v_packed, v_norms, centroids, rotation) + + H_q = q.shape[1] + H_kv = k.shape[1] + if H_q != H_kv: + k_dec = k_dec.repeat_interleave(H_q // H_kv, dim=1) + v_dec = v_dec.repeat_interleave(H_q // H_kv, dim=1) + + if attn_mask is not None and attn_mask.shape[1] == 1 and H_q > 1: + attn_mask = attn_mask.expand(-1, H_q, -1, -1) + + out = F.scaled_dot_product_attention( + q.float(), + k_dec.float(), + v_dec.float(), + attn_mask=attn_mask, + is_causal=is_causal, + ) + return out, k_packed, k_norms, v_packed, v_norms + + +def _cosine_sim(a, b): + return F.cosine_similarity( + a.reshape(-1).float(), b.reshape(-1).float(), dim=0 + ).item() + + +# Test configs +HEAD_DIMS = [64, 128, 256] +SEQLEN_PAIRS = [ + (1, 64), + (1, 128), + (4, 64), + (64, 64), + (128, 128), +] +GQA_CONFIGS = [ + (4, 4, "mha"), + (4, 2, "gqa_2x"), + (8, 2, "gqa_4x"), + (16, 2, "gqa_8x"), + (6, 1, "mqa"), +] + + +class TestTQ4Sdpa(unittest.TestCase): + """Test TQ4 fused SDPA kernel against decompress-then-SDPA reference.""" + + @classmethod + def setUpClass(cls): + _skip_if_no_cuda() + cls.tq4_sdpa = tq4_sdpa + + def _run_test( + self, B, H_q, H_kv, Lq, Lk, D, attn_mask=None, is_causal=False, min_cosine=0.95 + ): + torch.manual_seed(42) + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids = centroids.cuda() + boundaries = 
boundaries.cuda() + rotation = rotation.cuda() + + q = torch.randn(B, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(B, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + + ref_out, k_packed, k_norms, v_packed, v_norms = _reference_tq4_sdpa( + q, + k, + v, + centroids, + boundaries, + rotation, + attn_mask=attn_mask, + is_causal=is_causal, + ) + + out = self.tq4_sdpa( + q, + k_packed.cuda(), + k_norms.cuda(), + v_packed.cuda(), + v_norms.cuda(), + centroids, + rotation, + attn_mask=attn_mask, + is_causal=is_causal, + ) + + self.assertFalse(torch.isnan(out).any(), "NaN in output") + cos = _cosine_sim(out, ref_out) + self.assertGreater( + cos, + min_cosine, + f"Cosine {cos:.4f} < {min_cosine} " + f"(B={B} H_q={H_q} H_kv={H_kv} Lq={Lq} Lk={Lk} D={D})", + ) + + # ------------------------------------------------------------------ + # MHA (H_q == H_kv) + # ------------------------------------------------------------------ + + def test_mha_basic(self): + """MHA with various head dims and sequence lengths.""" + for D in HEAD_DIMS: + for Lq, Lk in SEQLEN_PAIRS: + if Lq > Lk: + continue + with self.subTest(D=D, Lq=Lq, Lk=Lk): + self._run_test(1, 4, 4, Lq, Lk, D) + + def test_mha_causal(self): + """MHA with is_causal=True.""" + for D in [64, 128, 256]: + for L in [64, 128]: + with self.subTest(D=D, L=L): + self._run_test(1, 4, 4, L, L, D, is_causal=True) + + def test_mha_causal_explicit_mask(self): + """MHA with causal masking via explicit bool mask.""" + for D in [64, 128, 256]: + for L in [64, 128]: + with self.subTest(D=D, L=L): + mask = torch.tril( + torch.ones(1, 1, L, L, dtype=torch.bool, device="cuda") + ) + self._run_test(1, 4, 4, L, L, D, attn_mask=mask) + + def test_mha_bool_mask(self): + """MHA with explicit bool attention mask.""" + D = 64 + for Lq, Lk in [(1, 64), (4, 128), (1, 256)]: + with self.subTest(Lq=Lq, Lk=Lk): + mask = torch.zeros(1, 1, Lq, Lk, dtype=torch.bool, 
device="cuda") + mask[:, :, :, : Lk // 2] = True + self._run_test(1, 4, 4, Lq, Lk, D, attn_mask=mask) + + # ------------------------------------------------------------------ + # GQA (H_q > H_kv) + # ------------------------------------------------------------------ + + def test_gqa_decode(self): + """GQA decode (seqlen_q=1).""" + for H_q, H_kv, label in GQA_CONFIGS: + if H_q == H_kv: + continue + for D in [64, 128, 256]: + with self.subTest(label=label, D=D): + self._run_test(1, H_q, H_kv, 1, 128, D) + + def test_gqa_prefill(self): + """GQA prefill (seqlen_q > 1) with is_causal=True.""" + for H_q, H_kv, label in GQA_CONFIGS: + if H_q == H_kv: + continue + with self.subTest(label=label): + self._run_test(1, H_q, H_kv, 64, 64, 128, is_causal=True) + + def test_gqa_8x_head_dim_256(self): + """GQA 8:1 with head_dim=256 — matches Qwen 3.5 MoE config.""" + self._run_test(1, 16, 2, 1, 128, 256) + L = 64 + mask = torch.tril(torch.ones(1, 1, L, L, dtype=torch.bool, device="cuda")) + self._run_test(1, 16, 2, L, L, 256, attn_mask=mask) + + def test_gqa_with_mask(self): + """GQA decode with explicit bool mask.""" + D, Lk = 128, 128 + mask = torch.ones(1, 1, 1, Lk, dtype=torch.bool, device="cuda") + mask[:, :, :, Lk // 2 :] = False + self._run_test(1, 8, 2, 1, Lk, D, attn_mask=mask) + + # ------------------------------------------------------------------ + # Edge cases + # ------------------------------------------------------------------ + + def test_batch_size_2(self): + """Batch size > 1.""" + self._run_test(2, 4, 2, 1, 64, 128) + + def test_short_kv(self): + """Short KV sequence (32 tokens).""" + self._run_test(1, 4, 4, 1, 32, 64) + + def test_all_masked_produces_zeros(self): + """Fully masked Q rows produce zero output, not NaN.""" + D = 64 + torch.manual_seed(42) + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, rotation = centroids.cuda(), rotation.cuda() + + q = torch.randn(1, 4, 1, D, dtype=torch.bfloat16, device="cuda") + k = 
torch.randn(1, 4, 64, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(1, 4, 64, D, dtype=torch.bfloat16, device="cuda") + + k_packed, k_norms = _compress(k, boundaries.cuda(), rotation) + v_packed, v_norms = _compress(v, boundaries.cuda(), rotation) + + # All-False mask: every KV position is masked out + mask = torch.zeros(1, 1, 1, 64, dtype=torch.bool, device="cuda") + out = self.tq4_sdpa( + q, + k_packed, + k_norms, + v_packed, + v_norms, + centroids, + rotation, + mask, + ) + self.assertFalse(torch.isnan(out).any(), "NaN in output with all-masked row") + self.assertFalse(torch.isinf(out).any(), "Inf in output with all-masked row") + self.assertEqual(out.abs().max().item(), 0.0) + + def test_sparse_mask_no_nan(self): + """Sparse mask with many all-masked tile blocks produces no NaN/Inf. + + Only a few KV positions are unmasked, so most tile blocks are entirely + masked (-inf). The softmax must not produce NaN from -inf - (-inf) + and must not propagate it into subsequent valid blocks. + """ + D, Lk = 64, 256 + torch.manual_seed(42) + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, boundaries, rotation = ( + centroids.cuda(), + boundaries.cuda(), + rotation.cuda(), + ) + + q = torch.randn(1, 4, 4, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 4, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(1, 4, Lk, D, dtype=torch.bfloat16, device="cuda") + + k_packed, k_norms = _compress(k, boundaries, rotation) + v_packed, v_norms = _compress(v, boundaries, rotation) + + # Sparse: only positions 100-103 unmasked, rest masked. 
+ mask = torch.zeros(1, 1, 4, Lk, dtype=torch.bool, device="cuda") + mask[:, :, :, 100:104] = True + + out = self.tq4_sdpa( + q, + k_packed, + k_norms, + v_packed, + v_norms, + centroids, + rotation, + mask, + ) + self.assertFalse(torch.isnan(out).any(), "NaN with sparse mask") + self.assertFalse(torch.isinf(out).any(), "Inf with sparse mask") + self.assertGreater(out.abs().max().item(), 0, "Output is all zeros") + + def test_float_mask_rejected(self): + """Float attention mask raises RuntimeError.""" + D = 64 + torch.manual_seed(42) + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, rotation = centroids.cuda(), rotation.cuda() + + q = torch.randn(1, 4, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 4, 64, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(1, 4, 64, D, dtype=torch.bfloat16, device="cuda") + k_packed, k_norms = _compress(k, boundaries.cuda(), rotation) + v_packed, v_norms = _compress(v, boundaries.cuda(), rotation) + + float_mask = torch.zeros(1, 1, 1, 64, dtype=torch.float32, device="cuda") + with self.assertRaises(RuntimeError): + self.tq4_sdpa( + q, + k_packed, + k_norms, + v_packed, + v_norms, + centroids, + rotation, + float_mask, + ) + + def test_qwen35_moe_config(self): + """Qwen 3.5 MoE: head_dim=256, GQA 16:2, decode + prefill.""" + self._run_test(1, 16, 2, 1, 256, 256) + self._run_test(1, 16, 2, 128, 128, 256, is_causal=True) + + def test_mqa(self): + """MQA (all Q heads share 1 KV head).""" + for D in [64, 128]: + with self.subTest(D=D): + self._run_test(1, 6, 1, 1, 128, D) + + def test_gqa_short_seqlen(self): + """GQA with short seqlen_q (2-8), exercises Pack GQA boundary.""" + for Lq in [2, 4, 8]: + with self.subTest(Lq=Lq): + self._run_test(1, 8, 2, Lq, 128, 128) + + def test_gqa_long_kv(self): + """GQA decode with longer KV sequences.""" + for Lk in [512, 1024]: + with self.subTest(Lk=Lk): + self._run_test(1, 16, 2, 1, Lk, 128) + + def test_gqa_causal_decode_with_cache_mask(self): + 
"""GQA decode with KV cache mask at various fill levels.""" + H_q, H_kv, D = 16, 2, 128 + for cache_len in [64, 256, 512]: + with self.subTest(cache_len=cache_len): + pos = cache_len * 3 // 4 + mask = torch.zeros(1, 1, 1, cache_len, dtype=torch.bool, device="cuda") + mask[:, :, :, :pos] = True + self._run_test(1, H_q, H_kv, 1, cache_len, D, attn_mask=mask) + + def test_output_shape_and_dtype(self): + """Output shape and dtype are correct for various configs.""" + for H_q, H_kv in [(4, 4), (8, 2), (6, 1)]: + for Lq, Lk in [(1, 64), (32, 64)]: + with self.subTest(H_q=H_q, H_kv=H_kv, Lq=Lq, Lk=Lk): + D = 64 + torch.manual_seed(42) + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, boundaries, rotation = ( + centroids.cuda(), + boundaries.cuda(), + rotation.cuda(), + ) + q = torch.randn(1, H_q, Lq, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(1, H_kv, Lk, D, dtype=torch.bfloat16, device="cuda") + k_p, k_n = _compress(k, boundaries, rotation) + v_p, v_n = _compress(v, boundaries, rotation) + out = self.tq4_sdpa( + q, + k_p, + k_n, + v_p, + v_n, + centroids, + rotation, + ) + self.assertEqual(out.shape, (1, H_q, Lq, D)) + self.assertEqual(out.dtype, torch.bfloat16) + + # ------------------------------------------------------------------ + # Validation errors + # ------------------------------------------------------------------ + + def test_hq_not_divisible_by_hkv_rejected(self): + """H_Q not divisible by H_KV raises RuntimeError.""" + D = 64 + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, boundaries, rotation = ( + centroids.cuda(), + boundaries.cuda(), + rotation.cuda(), + ) + q = torch.randn(1, 5, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 3, 64, D, dtype=torch.bfloat16, device="cuda") + k_p, k_n = _compress(k, boundaries, rotation) + v_p, v_n = _compress(k, boundaries, rotation) + with 
self.assertRaises(RuntimeError): + self.tq4_sdpa(q, k_p, k_n, v_p, v_n, centroids, rotation) + + def test_causal_lq_ne_lkv_rejected(self): + """is_causal=True with L_Q != L_KV raises RuntimeError.""" + D = 64 + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, boundaries, rotation = ( + centroids.cuda(), + boundaries.cuda(), + rotation.cuda(), + ) + q = torch.randn(1, 4, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 4, 64, D, dtype=torch.bfloat16, device="cuda") + k_p, k_n = _compress(k, boundaries, rotation) + v_p, v_n = _compress(k, boundaries, rotation) + with self.assertRaises(RuntimeError): + self.tq4_sdpa( + q, + k_p, + k_n, + v_p, + v_n, + centroids, + rotation, + is_causal=True, + ) + + def test_non_pow2_head_dim_rejected(self): + """Non-power-of-2 HEAD_DIM raises RuntimeError.""" + D = 80 + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, rotation = centroids.cuda(), rotation.cuda() + q = torch.randn(1, 4, 1, D, dtype=torch.bfloat16, device="cuda") + k_p = torch.zeros(1, 4, 64, D // 2, dtype=torch.uint8, device="cuda") + k_n = torch.zeros(1, 4, 64, 1, dtype=torch.bfloat16, device="cuda") + with self.assertRaises(RuntimeError): + self.tq4_sdpa(q, k_p, k_n, k_p, k_n, centroids, rotation) + + def test_per_head_mask_rejected(self): + """Per-head masks (H>1) should be rejected since the kernel broadcasts.""" + D = 64 + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, boundaries, rotation = ( + centroids.cuda(), + boundaries.cuda(), + rotation.cuda(), + ) + q = torch.randn(1, 4, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 4, 64, D, dtype=torch.bfloat16, device="cuda") + k_p, k_n = _compress(k, boundaries, rotation) + v_p, v_n = _compress(k, boundaries, rotation) + # H=4 instead of H=1 + mask = torch.ones(1, 4, 1, 64, dtype=torch.bool, device="cuda") + with self.assertRaises(RuntimeError): + self.tq4_sdpa(q, k_p, k_n, v_p, v_n, centroids, 
rotation, mask) + + def test_mask_shape_mismatch_rejected(self): + """Mask with wrong B/Lq/Lkv dims raises RuntimeError.""" + D = 64 + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + centroids, boundaries, rotation = ( + centroids.cuda(), + boundaries.cuda(), + rotation.cuda(), + ) + q = torch.randn(1, 4, 1, D, dtype=torch.bfloat16, device="cuda") + k = torch.randn(1, 4, 64, D, dtype=torch.bfloat16, device="cuda") + k_p, k_n = _compress(k, boundaries, rotation) + v_p, v_n = _compress(k, boundaries, rotation) + # Wrong Lkv: 32 instead of 64 + mask = torch.ones(1, 1, 1, 32, dtype=torch.bool, device="cuda") + with self.assertRaises(RuntimeError): + self.tq4_sdpa(q, k_p, k_n, v_p, v_n, centroids, rotation, mask) + + # ------------------------------------------------------------------ + # Full path: TurboQuantKVCache + tq4_sdpa + # ------------------------------------------------------------------ + + def test_kv_cache_plus_sdpa(self): + """TurboQuantKVCache.update() -> tq4_sdpa matches reference SDPA.""" + D, H_Q, H_KV, MAX_SEQ = 128, 8, 2, 64 + torch.manual_seed(42) + centroids, boundaries, rotation = _make_codebook_and_rotation(D) + + cache = TurboQuantKVCache(H_KV, D, MAX_SEQ).cuda() + + # Prefill 16 tokens + k_pf = torch.randn(1, H_KV, 16, D, dtype=torch.bfloat16, device="cuda") + v_pf = torch.randn(1, H_KV, 16, D, dtype=torch.bfloat16, device="cuda") + pos_pf = torch.arange(16, device="cuda") + k_packed, k_norms, v_packed, v_norms = cache.update(pos_pf, k_pf, v_pf) + + # Decode query + q = torch.randn(1, H_Q, 1, D, dtype=torch.bfloat16, device="cuda") + mask = torch.zeros(1, 1, 1, MAX_SEQ, dtype=torch.bool, device="cuda") + mask[:, :, :, :16] = True + + out = self.tq4_sdpa( + q, + k_packed, + k_norms, + v_packed, + v_norms, + cache.centroids, + cache.rotation, + mask, + ) + + # Reference: use test's own decompress + standard SDPA + k_dec = _decompress( + k_packed[:, :, :16], + k_norms[:, :, :16], + centroids.cuda(), + rotation.cuda(), + ) + 
v_dec = _decompress( + v_packed[:, :, :16], + v_norms[:, :, :16], + centroids.cuda(), + rotation.cuda(), + ) + k_dec = k_dec.repeat_interleave(H_Q // H_KV, dim=1) + v_dec = v_dec.repeat_interleave(H_Q // H_KV, dim=1) + ref = F.scaled_dot_product_attention(q.float(), k_dec.float(), v_dec.float()) + + cos = _cosine_sim(out, ref) + self.assertGreater(cos, 0.95, f"Cosine {cos:.4f}") + self.assertFalse(torch.isnan(out).any()) + + def test_kv_cache_decode_accumulates(self): + """Decode tokens accumulate in cache and affect attention output.""" + D, H_Q, H_KV, MAX_SEQ = 64, 4, 2, 32 + torch.manual_seed(42) + + cache = TurboQuantKVCache(H_KV, D, MAX_SEQ).cuda() + + # Insert tokens one at a time, keep update() return values + for i in range(8): + k = torch.randn(1, H_KV, 1, D, dtype=torch.bfloat16, device="cuda") + v = torch.randn(1, H_KV, 1, D, dtype=torch.bfloat16, device="cuda") + k_packed, k_norms, v_packed, v_norms = cache.update( + torch.tensor([i], device="cuda"), k, v + ) + + q = torch.randn(1, H_Q, 1, D, dtype=torch.bfloat16, device="cuda") + mask = torch.zeros(1, 1, 1, MAX_SEQ, dtype=torch.bool, device="cuda") + mask[:, :, :, :8] = True + + out = self.tq4_sdpa( + q, + k_packed, + k_norms, + v_packed, + v_norms, + cache.centroids, + cache.rotation, + mask, + ) + + self.assertFalse(torch.isnan(out).any()) + self.assertGreater(out.abs().max().item(), 0, "Output is all zeros") + + # ------------------------------------------------------------------ + # Export through CUDA backend + # ------------------------------------------------------------------ + + def test_export_cuda(self): + """Export tq4_sdpa through CudaPartitioner, verify .pte is produced.""" + with tempfile.TemporaryDirectory() as tmpdir: + pte_path, _ = _export_tq4_attn(tmpdir) + self.assertTrue(os.path.exists(pte_path)) + self.assertGreater(os.path.getsize(pte_path), 0) + + def test_e2e_cpp_runner(self): + """Export once, run executor_runner with multiple inputs, compare.""" + if not 
os.path.exists(RUNNER_PATH): + self.skipTest( + f"executor_runner not found at {RUNNER_PATH}. " + "Build with: cmake --build cmake-out --target executor_runner" + ) + + D, H_Q, SEQ = 128, 4, 64 + e2e_seeds = [0, 7, 42] + + with tempfile.TemporaryDirectory() as tmpdir: + export_dir = os.path.join(tmpdir, "export") + pte_path, model = _export_tq4_attn(export_dir) + ptd_path = os.path.join(export_dir, "aoti_cuda_blob.ptd") + + for seed in e2e_seeds: + with self.subTest(seed=seed): + inputs = _make_tq4_inputs(seed, H_Q, D, SEQ) + + with torch.no_grad(): + ref = model(*inputs) + + run_dir = os.path.join(tmpdir, f"run_seed{seed}") + os.makedirs(run_dir) + + input_files = [] + for i, tensor in enumerate(inputs): + path = os.path.join(run_dir, f"{i}.bin") + _save_tensor(tensor, path) + input_files.append(path) + + output_base = os.path.join(run_dir, "output") + result = _run_cpp_runner( + RUNNER_PATH, + pte_path, + ptd_path, + input_files, + output_base, + ) + self.assertEqual( + result.returncode, + 0, + f"seed={seed}: executor_runner failed:\n{result.stderr}", + ) + + cpp_out = _load_output( + f"{output_base}-0.bin", + (1, H_Q, 2, D), + torch.bfloat16, + ) + + cos = _cosine_sim(cpp_out, ref.cpu()) + self.assertGreater( + cos, + 0.99, + f"seed={seed}: cosine {cos:.4f}", + ) + + +# --------------------------------------------------------------------------- +# Export + runner helpers +# --------------------------------------------------------------------------- + +EXECUTORCH_ROOT = os.path.normpath(os.path.join(os.path.dirname(__file__), "../../..")) +RUNNER_PATH = os.path.join(EXECUTORCH_ROOT, "cmake-out", "executor_runner") + + +class _TQ4AttnModule(torch.nn.Module): + """Minimal module wrapping tq4_sdpa for export testing.""" + + def __init__(self, head_dim, h_q, h_kv, max_seq): + super().__init__() + centroids, boundaries, rotation = _make_codebook_and_rotation(head_dim) + self.register_buffer("centroids", centroids) + self.register_buffer("rotation", rotation) + + # 
Pre-populate with compressed random data so outputs are non-zero + k = torch.randn(1, h_kv, max_seq, head_dim) + v = torch.randn(1, h_kv, max_seq, head_dim) + k_packed, k_norms = _compress(k, boundaries, rotation) + v_packed, v_norms = _compress(v, boundaries, rotation) + self.register_buffer("k_packed", k_packed) + self.register_buffer("k_norms", k_norms.to(torch.bfloat16)) + self.register_buffer("v_packed", v_packed) + self.register_buffer("v_norms", v_norms.to(torch.bfloat16)) + + def forward(self, query, attn_mask): + return tq4_sdpa( + query, + self.k_packed, + self.k_norms, + self.v_packed, + self.v_norms, + self.centroids, + self.rotation, + attn_mask, + ) + + +def _export_tq4_attn(output_dir): + """Export a _TQ4AttnModule to .pte + .ptd. Returns (pte_path, model).""" + D, H_Q, H_KV, SEQ = 128, 4, 2, 64 + + torch.manual_seed(42) + model = _TQ4AttnModule(D, H_Q, H_KV, SEQ).to("cuda").eval() + inputs = _make_tq4_inputs(42, H_Q, D, SEQ) + + with torch.no_grad(): + ep = export(model, inputs, strict=True) + + os.makedirs(output_dir, exist_ok=True) + + specs = [CudaBackend.generate_method_name_compile_spec("forward")] + et_prog = to_edge_transform_and_lower( + ep, + partitioner=[CudaPartitioner(specs)], + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + pte_path = os.path.join(output_dir, "tq4_sdpa.pte") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + + if hasattr(et_program, "_tensor_data") and et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + + return pte_path, model + + +def _make_tq4_inputs(seed, h_q, head_dim, max_seq, device="cuda"): + torch.manual_seed(seed) + q = torch.randn(1, h_q, 2, head_dim, dtype=torch.bfloat16, device=device) + mask = 
torch.ones(1, 1, 2, max_seq, dtype=torch.bool, device=device) + return (q, mask) + + +def _save_tensor(t, path): + t_cpu = t.cpu().contiguous() + with open(path, "wb") as f: + f.write(bytes(t_cpu.untyped_storage())) + + +def _load_output(path, shape, dtype): + data = np.fromfile(path, dtype=np.uint8) + return torch.frombuffer(bytearray(data), dtype=dtype).reshape(shape) + + +def _run_cpp_runner(runner_path, pte_path, ptd_path, input_files, output_base): + cmd = [ + runner_path, + f"--model_path={pte_path}", + f"--data_path={ptd_path}", + f"--inputs={','.join(input_files)}", + f"--output_file={output_base}", + ] + return subprocess.run(cmd, capture_output=True, text=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/cuda/triton/kernels/__init__.py b/backends/cuda/triton/kernels/__init__.py index 1000f3e9773..e7af2bdaf84 100644 --- a/backends/cuda/triton/kernels/__init__.py +++ b/backends/cuda/triton/kernels/__init__.py @@ -22,3 +22,10 @@ __all__.append("chunk_gated_delta_rule") except ImportError: pass + +try: + from executorch.backends.cuda.triton.kernels.tq4_sdpa import tq4_sdpa # noqa: F401 + + __all__.append("tq4_sdpa") +except ImportError: + pass diff --git a/backends/cuda/triton/kernels/tq4_sdpa.py b/backends/cuda/triton/kernels/tq4_sdpa.py new file mode 100644 index 00000000000..a4748540342 --- /dev/null +++ b/backends/cuda/triton/kernels/tq4_sdpa.py @@ -0,0 +1,756 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# TQ4 fused Flash Attention kernel with Pack GQA optimization. +# +# Decompression logic adapted from turboquant-vllm v1.4.0 +# (Alberto-Codes/turboquant-vllm, Apache 2.0). +# Pack GQA and kernel structure adapted from sdpa.py in this directory. 
+# +# Compatible with: turboquant-vllm 1.4.0 +# +# Reference: arXiv 2504.19874 — "TurboQuant: Online Vector Quantization +# with Near-optimal Distortion Rate" (ICLR 2026). + +""" +Fused TQ4 SDPA: attention over nibble-packed compressed K/V cache. + +Both K and V tiles are decompressed inline from uint8 nibble-packed indices +in the attention inner loop. The full decompressed cache is never materialized. +Q is pre-rotated by Pi^T, output is post-rotated by Pi outside the kernel. + +Pack GQA (from FlashAttention) folds multiple Q heads sharing one KV head +into the M dimension, loading K/V only once per KV head group. This gives +up to NUM_GROUPS x reduction in K/V HBM traffic for decode. +""" + +import math +from typing import Optional + +import torch +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _should_pack_gqa(L_q: int, num_groups: int, block_m: int) -> bool: + """Decide whether to use Pack GQA based on tile utilization. + + Pack GQA folds multiple Q heads into the M dimension so they share + the same K/V loads. This helps when seqlen_q is small relative to + BLOCK_M (e.g., decode with seqlen_q=1). + + Heuristic from FlashAttention (hopper/heuristics.h, should_pack_gqa). 
+ """ + if num_groups <= 1: + return False + + def round_up(a, b): + return ((a + b - 1) // b) * b + + nopack_eff = L_q / round_up(L_q, block_m) + pack_eff = (L_q * num_groups) / round_up(L_q * num_groups, block_m) + return nopack_eff < 0.9 * pack_eff + + +# --------------------------------------------------------------------------- +# Kernel body +# --------------------------------------------------------------------------- + + +@triton.jit +def _tq4_sdpa_fwd_kernel_body( + Q_ptr, + KP_ptr, + KN_ptr, + VP_ptr, + VN_ptr, + LUT_hi_ptr, + LUT_lo_ptr, + Mask_ptr, + O_ptr, + B, + H_grid, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kpb, + stride_kph, + stride_kpn, + stride_kpd, + stride_knb, + stride_knh, + stride_knn, + stride_vpb, + stride_vph, + stride_vpn, + stride_vpd, + stride_vnb, + stride_vnh, + stride_vnn, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, + HALF_D: tl.constexpr, + NUM_GROUPS: tl.constexpr, + PACK_GQA: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_bh = tl.program_id(axis=1) + b = pid_bh // H_grid + h_grid = pid_bh % H_grid + + offs_packed = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, HEAD_DIM) + offs_d_half = tl.arange(0, HALF_D) + + if PACK_GQA: + seq_pos = offs_packed // NUM_GROUPS + h_within = offs_packed % NUM_GROUPS + h_q_rows = h_grid * NUM_GROUPS + h_within + h_kv = h_grid + row_valid = seq_pos < Lq + + q_ptrs = Q_ptr + ( + b * stride_qb + + h_q_rows[:, None] * stride_qh + + seq_pos[:, None] * stride_qm + + offs_d[None, :] * stride_qd + ) + else: + seq_pos = offs_packed + h_kv = h_grid // NUM_GROUPS + row_valid = offs_packed < Lq + + q_ptrs = Q_ptr + ( + b * stride_qb + + h_grid * stride_qh + + offs_packed[:, None] * stride_qm + + offs_d[None, :] * stride_qd + ) + + q = tl.load(q_ptrs, 
mask=row_valid[:, None], other=0.0).to(tl.bfloat16) + + m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # Prescale for exp2-based softmax (single PTX instruction) + qk_scale = sm_scale * 1.44269504 + + # TQ4 K/V base pointers (uniform: single KV head) + kp_base = KP_ptr + b * stride_kpb + h_kv * stride_kph + kn_base = KN_ptr + b * stride_knb + h_kv * stride_knh + vp_base = VP_ptr + b * stride_vpb + h_kv * stride_vph + vn_base = VN_ptr + b * stride_vnb + h_kv * stride_vnh + + offs_n_init = tl.arange(0, BLOCK_N) + + for start_n in tl.range(0, Lk, BLOCK_N): + offs_n = start_n + offs_n_init + kv_valid = offs_n < Lk + + # -- K decompression (LUT, no norm multiply on [N,D] tile) -- + kp_ptrs = ( + kp_base + offs_n[:, None] * stride_kpn + offs_d_half[None, :] * stride_kpd + ) + k_packed_data = tl.load(kp_ptrs, mask=kv_valid[:, None], other=0).to(tl.int32) + k = tl.join( + tl.load(LUT_hi_ptr + k_packed_data), + tl.load(LUT_lo_ptr + k_packed_data), + ).reshape(BLOCK_N, HEAD_DIM) + + # Q @ K^T with norm factored out: Q @ (C*n)^T = (Q @ C^T) * n^T + kn = tl.load(kn_base + offs_n * stride_knn, mask=kv_valid, other=0.0) + qk = (tl.dot(q, tl.trans(k)) * qk_scale * kn[None, :]).to(tl.float32) + + if HAS_MASK: + mask_ptrs = Mask_ptr + ( + b * stride_mb + + seq_pos[:, None] * stride_mq + + offs_n[None, :] * stride_mk + ) + mn_mask = row_valid[:, None] & kv_valid[None, :] + mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) + qk = tl.where(mask_block, qk, float("-inf")) + + if IS_CAUSAL: + causal = offs_n[None, :] > seq_pos[:, None] + qk = tl.where(causal, float("-inf"), qk) + + qk = tl.where(kv_valid[None, :], qk, float("-inf")) + + # NaN-safe online softmax (exp2) + m_ij = tl.max(qk, 1) + m_new = tl.maximum(m_i, m_ij) + # Guard against -inf - (-inf) = NaN when all positions are masked + safe_alpha = tl.where(m_new > -float("inf"), m_i - m_new, 0.0) + alpha = 
tl.math.exp2(safe_alpha) + safe_p = tl.where( + m_new[:, None] > -float("inf"), qk - m_new[:, None], -float("inf") + ) + p = tl.math.exp2(safe_p) + l_ij = tl.sum(p, 1) + + # -- V decompression (LUT, norm factored into P) -- + vp_ptrs = ( + vp_base + offs_n[:, None] * stride_vpn + offs_d_half[None, :] * stride_vpd + ) + v_packed_data = tl.load(vp_ptrs, mask=kv_valid[:, None], other=0).to(tl.int32) + v = tl.join( + tl.load(LUT_hi_ptr + v_packed_data), + tl.load(LUT_lo_ptr + v_packed_data), + ).reshape(BLOCK_N, HEAD_DIM) + + # P @ (C*n) = (P*n) @ C — factor norm into P instead of V + vn = tl.load(vn_base + offs_n * stride_vnn, mask=kv_valid, other=0.0) + p_scaled = (p * vn[None, :]).to(tl.bfloat16) + acc = (acc * alpha[:, None] + tl.dot(p_scaled, v)).to(tl.float32) + l_i = (l_i * alpha + l_ij).to(tl.float32) + m_i = m_new + + inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0) + acc = acc * inv_l_i[:, None] + + # O store: scattered when PACK_GQA, uniform otherwise + if PACK_GQA: + o_ptrs = O_ptr + ( + b * stride_ob + + h_q_rows[:, None] * stride_oh + + seq_pos[:, None] * stride_om + + offs_d[None, :] * stride_od + ) + else: + o_ptrs = O_ptr + ( + b * stride_ob + + h_grid * stride_oh + + offs_packed[:, None] * stride_om + + offs_d[None, :] * stride_od + ) + tl.store(o_ptrs, acc.to(tl.bfloat16), mask=row_valid[:, None]) + + +# --------------------------------------------------------------------------- +# Autotuned kernel wrappers (M64 and M32) +# --------------------------------------------------------------------------- + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=8, num_stages=3), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", 
"IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], +) +@triton.jit +def _tq4_sdpa_fwd_kernel_m64( + Q_ptr, + KP_ptr, + KN_ptr, + VP_ptr, + VN_ptr, + LUT_hi_ptr, + LUT_lo_ptr, + Mask_ptr, + O_ptr, + B, + H_grid, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kpb, + stride_kph, + stride_kpn, + stride_kpd, + stride_knb, + stride_knh, + stride_knn, + stride_vpb, + stride_vph, + stride_vpn, + stride_vpd, + stride_vnb, + stride_vnh, + stride_vnn, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + HALF_D: tl.constexpr, + NUM_GROUPS: tl.constexpr, + PACK_GQA: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + _tq4_sdpa_fwd_kernel_body( + Q_ptr, + KP_ptr, + KN_ptr, + VP_ptr, + VN_ptr, + LUT_hi_ptr, + LUT_lo_ptr, + Mask_ptr, + O_ptr, + B, + H_grid, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kpb, + stride_kph, + stride_kpn, + stride_kpd, + stride_knb, + stride_knh, + stride_knn, + stride_vpb, + stride_vph, + stride_vpn, + stride_vpd, + stride_vnb, + stride_vnh, + stride_vnn, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, + HALF_D=HALF_D, + NUM_GROUPS=NUM_GROUPS, + PACK_GQA=PACK_GQA, + ) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32}, num_warps=4, num_stages=2), + ], + key=["Lq", "Lk", "HEAD_DIM", "HAS_MASK", "IS_CAUSAL", "NUM_GROUPS", "PACK_GQA"], +) +@triton.jit +def _tq4_sdpa_fwd_kernel_m32( + Q_ptr, + KP_ptr, + KN_ptr, + VP_ptr, + VN_ptr, + LUT_hi_ptr, 
+ LUT_lo_ptr, + Mask_ptr, + O_ptr, + B, + H_grid, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kpb, + stride_kph, + stride_kpn, + stride_kpd, + stride_knb, + stride_knh, + stride_knn, + stride_vpb, + stride_vph, + stride_vpn, + stride_vpd, + stride_vnb, + stride_vnh, + stride_vnn, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale: tl.float32, + HAS_MASK: tl.constexpr, + IS_CAUSAL: tl.constexpr, + HEAD_DIM: tl.constexpr, + HALF_D: tl.constexpr, + NUM_GROUPS: tl.constexpr, + PACK_GQA: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + _tq4_sdpa_fwd_kernel_body( + Q_ptr, + KP_ptr, + KN_ptr, + VP_ptr, + VN_ptr, + LUT_hi_ptr, + LUT_lo_ptr, + Mask_ptr, + O_ptr, + B, + H_grid, + Lq, + Lk, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kpb, + stride_kph, + stride_kpn, + stride_kpd, + stride_knb, + stride_knh, + stride_knn, + stride_vpb, + stride_vph, + stride_vpn, + stride_vpd, + stride_vnb, + stride_vnh, + stride_vnn, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=IS_CAUSAL, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM, + HALF_D=HALF_D, + NUM_GROUPS=NUM_GROUPS, + PACK_GQA=PACK_GQA, + ) + + +# --------------------------------------------------------------------------- +# Host-side launcher +# --------------------------------------------------------------------------- + + +def _launch_tq4_kernel( + q_rot, + k_packed, + k_norms, + v_packed, + v_norms, + lut_hi, + lut_lo, + mask_ptr, + out_rot, + B, + H_Q, + H_KV, + L_Q, + L_KV, + D, + sm_scale, + HAS_MASK, + stride_mb, + stride_mq, + stride_mk, + is_causal, + num_groups, + pack_gqa, +): + HALF_D = D // 2 + + if pack_gqa: + H_grid = H_KV + Lq_packed = L_Q * num_groups + else: + H_grid = H_Q + Lq_packed = L_Q + + def grid(meta): + return (triton.cdiv(Lq_packed, meta["BLOCK_M"]), B * H_grid) + + total_ctas_m64 = 
((Lq_packed + 63) // 64) * (B * H_grid) + threshold = 4 * 84 + kernel = ( + _tq4_sdpa_fwd_kernel_m32 + if total_ctas_m64 < threshold + else _tq4_sdpa_fwd_kernel_m64 + ) + + wrap_triton(kernel)[grid]( + q_rot, + k_packed, + k_norms, + v_packed, + v_norms, + lut_hi, + lut_lo, + mask_ptr if HAS_MASK else 0, + out_rot, + B, + H_grid, + L_Q, + L_KV, + *q_rot.stride(), + *k_packed.stride(), + *k_norms.stride(), + *v_packed.stride(), + *v_norms.stride(), + *out_rot.stride(), + stride_mb, + stride_mq, + stride_mk, + sm_scale, + HAS_MASK=HAS_MASK, + IS_CAUSAL=is_causal, + HEAD_DIM=D, + HALF_D=HALF_D, + NUM_GROUPS=num_groups, + PACK_GQA=pack_gqa, + ) + + +# --------------------------------------------------------------------------- +# @triton_op wrapper +# --------------------------------------------------------------------------- + + +def _validate_tq4_inputs(query, k_packed, v_packed): + """Validate tensor shapes, dtypes, and device for tq4_sdpa.""" + B, H_Q, N_Q, D = query.shape + B_kp, H_KV, N_KV, HALF_D = k_packed.shape + + if not query.is_cuda: + raise RuntimeError("query must be a CUDA tensor") + if query.dtype != torch.bfloat16: + raise RuntimeError(f"query must be bfloat16, got {query.dtype}") + if query.dim() != 4: + raise RuntimeError(f"query must be 4D [B, H, L, D], got {query.dim()}D") + if k_packed.dim() != 4 or v_packed.dim() != 4: + raise RuntimeError("k_packed and v_packed must be 4D [B, H, L, D//2]") + if k_packed.dtype != torch.uint8 or v_packed.dtype != torch.uint8: + raise RuntimeError("k_packed and v_packed must be uint8") + if B_kp != B: + raise RuntimeError( + f"Batch dim mismatch: query has B={B}, k_packed has B={B_kp}" + ) + if H_Q % H_KV != 0: + raise RuntimeError( + f"H_Q must be a multiple of H_KV for GQA head mapping, " + f"got H_Q={H_Q}, H_KV={H_KV}" + ) + if HALF_D * 2 != D: + raise RuntimeError( + f"k_packed last dim ({HALF_D}) * 2 must equal query head_dim ({D})" + ) + if D & (D - 1) != 0: + raise RuntimeError( + f"HEAD_DIM must be a power 
of 2, got {D}. " + "Non-power-of-2 head dims are not supported." + ) + + +def _validate_tq4_mask(attn_mask, B, N_Q, N_KV): + """Validate attention mask for tq4_sdpa.""" + if attn_mask is None: + return + if attn_mask.dtype != torch.bool: + raise RuntimeError( + f"attn_mask must be bool, got {attn_mask.dtype}. " + "Additive float masks are not supported." + ) + if not attn_mask.is_cuda: + raise RuntimeError("attn_mask must be a CUDA tensor") + if attn_mask.shape[1] != 1: + raise RuntimeError( + f"attn_mask head dimension must be 1 (broadcast over heads); " + f"per-head masks are not supported. " + f"Got attn_mask.shape={attn_mask.shape}" + ) + if ( + attn_mask.shape[0] != B + or attn_mask.shape[2] != N_Q + or attn_mask.shape[3] != N_KV + ): + raise RuntimeError( + f"attn_mask shape mismatch: expected " + f"[B={B}, 1, L_Q={N_Q}, L_KV={N_KV}], " + f"got {attn_mask.shape}" + ) + + +@triton_op("triton::tq4_sdpa", mutates_args={}) +def tq4_sdpa( + query: torch.Tensor, + k_packed: torch.Tensor, + k_norms: torch.Tensor, + v_packed: torch.Tensor, + v_norms: torch.Tensor, + centroids: torch.Tensor, + rotation: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, +) -> torch.Tensor: + """Fused TQ4 SDPA over nibble-packed compressed K/V cache. + + Decompresses K/V per tile in the attention inner loop. The full + decompressed cache is never materialized (3.8x memory savings). + + H_Q must be a multiple of H_KV (GQA/MQA). HEAD_DIM must be a + power of 2. The kernel maps Q heads to KV heads internally via + Pack GQA when beneficial. 
+ + Args: + query: [B, H_Q, L_Q, D] bf16 + k_packed: [B, H_KV, L_KV, D//2] uint8 nibble-packed key indices + k_norms: [B, H_KV, L_KV, 1] key vector norms (float or bf16) + v_packed: [B, H_KV, L_KV, D//2] uint8 nibble-packed value indices + v_norms: [B, H_KV, L_KV, 1] value vector norms (float or bf16) + centroids: [16] fp32 Lloyd-Max codebook + rotation: [D, D] orthogonal rotation matrix + attn_mask: Optional [B, 1, L_Q, L_KV] bool mask + is_causal: apply causal masking (requires L_Q == L_KV) + + Returns: + [B, H_Q, L_Q, D] bf16 attention output + """ + _validate_tq4_inputs(query, k_packed, v_packed) + + B, H_Q, N_Q, D = query.shape + _, H_KV, N_KV, HALF_D = k_packed.shape + + _validate_tq4_mask(attn_mask, B, N_Q, N_KV) + + sm_scale = 1.0 / math.sqrt(D) + num_groups = H_Q // H_KV + + # Build [256] bf16 lookup tables from [16] centroids. + # In the export path, inductor fuses this into the compiled graph. + all_bytes = torch.arange(256, device=centroids.device) + lut_hi = centroids[(all_bytes >> 4).long()].to(query.dtype).contiguous() + lut_lo = centroids[(all_bytes & 0x0F).long()].to(query.dtype).contiguous() + + # Reshape norms: [B, H, S, 1] -> [B, H, S] + k_n = k_norms.reshape(B, H_KV, N_KV).contiguous() + v_n = v_norms.reshape(B, H_KV, N_KV).contiguous() + + # Pre-rotate Q: Q_rot = Q @ Pi^T (bf16 — TQ4 error dominates) + q_rot = torch.matmul(query, rotation.T.to(query.dtype)).contiguous() + + out_rot = torch.empty_like(query) + + HAS_MASK = attn_mask is not None + if is_causal and N_Q != N_KV: + raise RuntimeError( + f"is_causal requires L_Q == L_KV, got L_Q={N_Q}, L_KV={N_KV}. " + "For decode (L_Q < L_KV), use an explicit bool mask instead." 
+ ) + if HAS_MASK: + mask_ptr = attn_mask + stride_mb = attn_mask.stride(0) + stride_mq = attn_mask.stride(2) + stride_mk = attn_mask.stride(3) + else: + mask_ptr = 0 + stride_mb = 0 + stride_mq = 0 + stride_mk = 0 + + # Pack GQA decision + total_ctas_m64 = ((N_Q * num_groups + 63) // 64) * (B * H_KV) + block_m = 32 if total_ctas_m64 < 4 * 84 else 64 + pack_gqa = _should_pack_gqa(N_Q, num_groups, block_m) + + _launch_tq4_kernel( + q_rot, + k_packed, + k_n, + v_packed, + v_n, + lut_hi, + lut_lo, + mask_ptr, + out_rot, + B, + H_Q, + H_KV, + N_Q, + N_KV, + D, + sm_scale, + HAS_MASK, + stride_mb, + stride_mq, + stride_mk, + is_causal, + num_groups, + pack_gqa, + ) + + # Post-rotate: convert from rotated space back to original space + return torch.matmul(out_rot, rotation.to(query.dtype)) + + +@tq4_sdpa.register_fake +def _tq4_sdpa_fake( + query: torch.Tensor, + k_packed: torch.Tensor, + k_norms: torch.Tensor, + v_packed: torch.Tensor, + v_norms: torch.Tensor, + centroids: torch.Tensor, + rotation: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + is_causal: bool = False, +) -> torch.Tensor: + return torch.empty_like(query) diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index 39e68bdd102..947f02acee6 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -50,6 +50,16 @@ python export.py \ | `--qembedding` | (none) | Embedding quantization: `8w` | | `--hqq` | off | Use HQQ scale-only optimization for expert quantization (slower, better accuracy) | | `--prequantized` | (none) | Path to prequantized bundle directory (skips quantization) | +| `--turboquant` | off | Enable TurboQuant TQ4 KV cache compression (3.8x cache savings) | + +### TurboQuant KV Cache Compression + +The `--turboquant` flag enables [TurboQuant](https://arxiv.org/abs/2504.19874) +KV cache compression (3.8x savings) on the 10 full-attention layers. 
+ +```bash +python export.py --prequantized qwen35_moe_int4_hqq --turboquant +``` ### Prequantized Export diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 01125dc75e3..7437bc5f461 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -110,7 +110,7 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096): ".kv_cache.", ".conv_state", ".recurrent_state", - ".mask", + ".cache_positions", ".inv_freq", ) expected_missing = {k for k in missing if any(p in k for p in runtime_prefixes)} @@ -339,13 +339,37 @@ def _materialize_buffers(model, config): ) rope.inv_freq = inv_freq - # Recompute causal masks for full attention layers + # Recompute cache_positions for full attention layers for layer in model.layers: - if hasattr(layer.attn, "mask"): - mask = torch.tril( - torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool) + if hasattr(layer.attn, "cache_positions"): + layer.attn.cache_positions = torch.arange( + config.max_seq_len, dtype=torch.long ) - layer.attn.register_buffer("mask", mask) + + +def _apply_turboquant(model, config): + """Replace KV caches in full-attention layers with TurboQuantKVCache. + + Runs after _materialize_buffers so the new TQ4 buffers are created + with correct dtypes and not affected by any blanket cast in _quantize. 
+ """ + from executorch.extension.llm.modules.turboquant import TurboQuantKVCache + + count = 0 + for layer in model.layers: + if layer.layer_type != "full_attention": + continue + old_cache = layer.attn.kv_cache + _, n_heads, max_seq_len, head_dim = old_cache.k_cache.shape + layer.attn.kv_cache = TurboQuantKVCache( + n_heads, + head_dim, + max_seq_len, + ) + layer.attn.turboquant = True + count += 1 + + print(f"Replaced {count} KV caches with TurboQuantKVCache (TQ4)") # --------------------------------------------------------------------------- @@ -480,6 +504,11 @@ def main(): "containing model.safetensors and config.json. " "Skips quantization; --model-dir is not needed.", ) + parser.add_argument( + "--turboquant", + action="store_true", + help="Enable TurboQuant TQ4 KV cache compression (3.8x cache savings).", + ) args = parser.parse_args() if not args.prequantized and not args.model_dir: @@ -493,6 +522,10 @@ def main(): model, config = load_and_quantize(args) _materialize_buffers(model, config) + + if args.turboquant: + _apply_turboquant(model, config) + export_and_lower(model, config, args) diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index 1e2abba2b9f..d9f127d9ed1 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -229,11 +229,12 @@ def __init__(self, config): ) self.kv_cache = KVCache(self.n_kv_heads, self.head_dim, config.max_seq_len) + self.turboquant = False - mask = torch.tril( - torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool) + self.register_buffer( + "cache_positions", + torch.arange(config.max_seq_len, dtype=torch.long), ) - self.register_buffer("mask", mask) def forward(self, x, input_pos): B, T, _ = x.size() @@ -264,15 +265,30 @@ def forward(self, x, input_pos): k = k.to(dtype).transpose(1, 2) v = v.transpose(1, 2) - # KV cache - k, v = self.kv_cache.update(input_pos, k, v) - - # SDPA with GQA — kernel maps Q heads to KV heads internally - 
attn_mask = self.mask[input_pos].unsqueeze(0).unsqueeze(0) - y = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, enable_gqa=True + attn_mask = ( + (self.cache_positions[None, :] <= input_pos[:, None]) + .unsqueeze(0) + .unsqueeze(0) ) + if self.turboquant: + k_packed, k_norms, v_packed, v_norms = self.kv_cache.update(input_pos, k, v) + y = torch.ops.triton.tq4_sdpa( + q, + k_packed, + k_norms, + v_packed, + v_norms, + self.kv_cache.centroids, + self.kv_cache.rotation, + attn_mask, + ) + else: + k, v = self.kv_cache.update(input_pos, k, v) + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, enable_gqa=True + ) + y = y.transpose(1, 2).contiguous().view(B, T, -1) # Output gate diff --git a/examples/models/qwen3_5_moe/test_turboquant.py b/examples/models/qwen3_5_moe/test_turboquant.py new file mode 100644 index 00000000000..53474dc2515 --- /dev/null +++ b/examples/models/qwen3_5_moe/test_turboquant.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test TurboQuant KV cache on a tiny Qwen 3.5 MoE model. + +Creates a tiny model (no downloads needed), quantizes weights, applies +TurboQuant KV cache compression, exports with torch.export, and verifies +the exported program produces correct output. + +Requires CUDA (fused_moe and tq4_sdpa Triton kernels). 
+ +Usage: + python -m pytest examples/models/qwen3_5_moe/test_turboquant.py -v +""" + +import unittest + +import torch + +from executorch.examples.models.qwen3_5_moe.export import ( + _apply_turboquant, + _materialize_buffers, + _quantize, +) +from executorch.examples.models.qwen3_5_moe.model import Qwen35MoE, Qwen35MoEConfig +from torch.export import Dim, export + + +TINY_CONFIG = Qwen35MoEConfig( + vocab_size=256, + hidden_size=128, + num_hidden_layers=2, + num_attention_heads=2, + num_kv_heads=2, + head_dim=64, + partial_rotary_factor=0.25, + linear_num_key_heads=2, + linear_num_value_heads=2, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_conv_kernel_dim=4, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=128, + shared_expert_intermediate_size=128, + full_attention_interval=2, + rms_norm_eps=1e-6, + rope_theta=10_000.0, + max_seq_len=64, +) + + +def _make_model(turboquant=False): + """Create a tiny quantized model, optionally with TurboQuant KV cache.""" + import executorch.backends.cuda.triton.kernels # noqa: F401 + + torch.manual_seed(42) + model = Qwen35MoE(TINY_CONFIG) + model.to(dtype=torch.bfloat16) + for p in model.parameters(): + if p.device.type != "meta": + p.data.normal_(0, 0.02) + model.eval() + + class Args: + qlinear = "4w" + qembedding = None + qlinear_group_size = 32 + hqq = False + + _quantize(model, TINY_CONFIG, Args()) + _materialize_buffers(model, TINY_CONFIG) + + if turboquant: + _apply_turboquant(model, TINY_CONFIG) + + model.to("cuda") + return model + + +def _greedy_decode(forward_fn, prompt, num_tokens): + """Greedy decode using forward(tokens, input_pos) signature.""" + for i, tok_id in enumerate(prompt): + logits = forward_fn( + torch.tensor([[tok_id]], dtype=torch.long, device="cuda"), + torch.tensor([i], dtype=torch.long, device="cuda"), + ) + generated = [] + next_tok = logits[0, -1].argmax().item() + generated.append(next_tok) + for i in range(num_tokens - 1): + logits = forward_fn( + 
torch.tensor([[next_tok]], dtype=torch.long, device="cuda"), + torch.tensor([len(prompt) + i], dtype=torch.long, device="cuda"), + ) + next_tok = logits[0, -1].argmax().item() + generated.append(next_tok) + return generated + + +class TestTurboQuant(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + self.skipTest("CUDA required") + + def test_eager_quality(self): + """TurboQuant model has >99% cosine similarity to baseline.""" + model_base = _make_model(turboquant=False) + model_tq = _make_model(turboquant=True) + + tokens = torch.tensor([[1, 2, 3, 4]], dtype=torch.long, device="cuda") + input_pos = torch.arange(4, dtype=torch.long, device="cuda") + + with torch.no_grad(): + logits_base = model_base(tokens, input_pos) + logits_tq = model_tq(tokens, input_pos) + + cos = torch.nn.functional.cosine_similarity( + logits_base.reshape(1, -1).float(), + logits_tq.reshape(1, -1).float(), + ).item() + self.assertGreater(cos, 0.99, f"Cosine {cos:.4f}") + + def test_eager_decode_quality(self): + """TurboQuant decode logits stay close to baseline across steps.""" + model_base = _make_model(turboquant=False) + model_tq = _make_model(turboquant=True) + + prompt = [1, 2, 3] + for i, tok_id in enumerate(prompt): + tok = torch.tensor([[tok_id]], dtype=torch.long, device="cuda") + pos = torch.tensor([i], dtype=torch.long, device="cuda") + with torch.no_grad(): + logits_base = model_base(tok, pos) + logits_tq = model_tq(tok, pos) + + # Check cosine similarity of logits after prefill + cos = torch.nn.functional.cosine_similarity( + logits_base.reshape(1, -1).float(), + logits_tq.reshape(1, -1).float(), + ).item() + self.assertGreater(cos, 0.99, f"Prefill cosine {cos:.4f}") + + def test_export_matches_eager(self): + """Exported TQ model produces same greedy tokens as eager.""" + model = _make_model(turboquant=True) + + def eager_fn(tok, pos): + with torch.no_grad(): + return model(tok, pos) + + eager_tokens = _greedy_decode(eager_fn, [1, 2, 3], 5) + + # Export 
+ seq_dim = Dim("seq_len", min=1, max=TINY_CONFIG.max_seq_len - 1) + with torch.no_grad(): + ep = export( + model, + ( + torch.tensor([[0, 1]], dtype=torch.long, device="cuda"), + torch.tensor([0, 1], dtype=torch.long, device="cuda"), + ), + dynamic_shapes=({1: seq_dim}, {0: seq_dim}), + strict=True, + ) + ep_mod = ep.module() + + def exported_fn(tok, pos): + with torch.no_grad(): + return ep_mod(tok, pos) + + exported_tokens = _greedy_decode(exported_fn, [1, 2, 3], 5) + self.assertEqual(eager_tokens, exported_tokens) + + def test_kv_cache_state_matters(self): + """Different prefills produce different continuations.""" + model_a = _make_model(turboquant=True) + model_b = _make_model(turboquant=True) + + def fn_a(tok, pos): + with torch.no_grad(): + return model_a(tok, pos) + + def fn_b(tok, pos): + with torch.no_grad(): + return model_b(tok, pos) + + tokens_a = _greedy_decode(fn_a, [1, 2, 3], 3) + tokens_b = _greedy_decode(fn_b, [10, 20, 30], 3) + self.assertNotEqual(tokens_a, tokens_b) + + def test_replacement_count(self): + """_apply_turboquant replaces exactly the full-attention layers.""" + model = _make_model(turboquant=True) + from executorch.extension.llm.modules.turboquant import TurboQuantKVCache + + tq_count = sum( + 1 + for layer in model.layers + if isinstance(getattr(layer.attn, "kv_cache", None), TurboQuantKVCache) + ) + fa_count = sum(1 for lt in TINY_CONFIG.layer_types if lt == "full_attention") + self.assertEqual(tq_count, fa_count) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/llm/modules/test/test_turboquant_kv_cache.py b/extension/llm/modules/test/test_turboquant_kv_cache.py new file mode 100644 index 00000000000..4173fbbcd58 --- /dev/null +++ b/extension/llm/modules/test/test_turboquant_kv_cache.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for TurboQuantKVCache module. + +Verifies KV cache update/roundtrip quality, nibble packing correctness, +codebook/rotation properties, and torch.export compatibility. + +No CUDA required — all tests run on CPU. + +Usage: + python -m pytest extension/llm/modules/test/test_turboquant_kv_cache.py -v +""" + +import unittest + +import torch + +from executorch.extension.llm.modules.turboquant import TurboQuantKVCache +from executorch.extension.llm.modules.turboquant.codebook import ( + generate_rotation_matrix, + solve_lloyd_max, +) +from torch.export import Dim, export + +HEAD_DIM = 128 +N_HEADS = 2 +MAX_SEQ_LEN = 32 +BITS = 4 + + +def _roundtrip_cosine(cache, x, input_pos): + """Update cache and measure roundtrip quality via cosine similarity. + + Reconstructs from the returned packed/norms using the test's own + decompress logic (independent of cache internals). 
+ """ + k_p, k_n, _, _ = cache.update(input_pos, x, x) + T = x.shape[2] + + # Decompress using centroids + rotation from the cache (public buffers) + flat = k_p[:, :, input_pos].reshape(-1, HEAD_DIM // 2) + flat_norms = k_n[:, :, input_pos].reshape(-1, 1).float() + high = (flat >> 4).long() + low = (flat & 0x0F).long() + indices = torch.stack([high, low], dim=-1).reshape(-1, HEAD_DIM) + reconstructed = cache.centroids.float()[indices] + unrotated = reconstructed @ cache.rotation_T.float().T + recon = (unrotated * flat_norms).reshape(1, N_HEADS, T, HEAD_DIM) + + return torch.nn.functional.cosine_similarity( + x.reshape(-1, HEAD_DIM).float(), + recon.reshape(-1, HEAD_DIM).float(), + ).mean() + + +class TestCacheUpdate(unittest.TestCase): + """update() produces high-quality compressed representation.""" + + def test_roundtrip_quality(self): + cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS) + x = torch.randn(1, N_HEADS, 10, HEAD_DIM) + pos = torch.arange(10) + cos = _roundtrip_cosine(cache, x, pos) + self.assertGreater(cos.item(), 0.99) + + def test_output_shapes(self): + cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS) + x = torch.randn(1, N_HEADS, 5, HEAD_DIM) + k_p, k_n, v_p, v_n = cache.update(torch.arange(5), x, x) + + self.assertEqual(k_p.shape, (1, N_HEADS, MAX_SEQ_LEN, HEAD_DIM // 2)) + self.assertEqual(k_p.dtype, torch.uint8) + self.assertEqual(k_n.shape, (1, N_HEADS, MAX_SEQ_LEN, 1)) + self.assertEqual(v_p.shape, k_p.shape) + self.assertEqual(v_n.shape, k_n.shape) + + def test_bf16_input(self): + cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS) + x = torch.randn(1, N_HEADS, 5, HEAD_DIM, dtype=torch.bfloat16) + pos = torch.arange(5) + cos = _roundtrip_cosine(cache, x, pos) + self.assertGreater(cos.item(), 0.99) + + def test_state_accumulates(self): + """Writing to different positions preserves earlier data.""" + cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS) + + x0 = torch.randn(1, N_HEADS, 4, 
HEAD_DIM) + cache.update(torch.arange(4), x0, x0) + + x1 = torch.randn(1, N_HEADS, 4, HEAD_DIM) + k_p, k_n, _, _ = cache.update(torch.arange(4, 8), x1, x1) + + # Positions 0-3 should still have x0's data + cos = _roundtrip_cosine(cache, x0, torch.arange(4)) + self.assertGreater(cos.item(), 0.99) + + # Positions 8+ should be zero (never written) + self.assertEqual(k_p[:, :, 8:].abs().max().item(), 0) + + def test_head_dim_256(self): + """Qwen 3.5 MoE config.""" + cache = TurboQuantKVCache(2, 256, 64) + x = torch.randn(1, 2, 5, 256) + k_p, k_n, v_p, v_n = cache.update(torch.arange(5), x, x) + self.assertEqual(k_p.shape, (1, 2, 64, 128)) + self.assertEqual(k_p.dtype, torch.uint8) + + +class TestNibblePacking(unittest.TestCase): + """uint8 nibble pack/unpack is bit-exact.""" + + def test_roundtrip_all_index_pairs(self): + all_pairs = torch.stack( + torch.meshgrid(torch.arange(16), torch.arange(16), indexing="ij"), + dim=-1, + ).reshape(-1, 2) + + packed = (all_pairs[:, 0].to(torch.uint8) << 4) | all_pairs[:, 1].to( + torch.uint8 + ) + + high = (packed >> 4).long() + low = (packed & 0x0F).long() + + self.assertTrue(torch.equal(high, all_pairs[:, 0].long())) + self.assertTrue(torch.equal(low, all_pairs[:, 1].long())) + + +class TestTorchExport(unittest.TestCase): + """TurboQuantKVCache survives torch.export(strict=True).""" + + def test_export_standalone(self): + cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS) + seq_dim = Dim("seq", min=1, max=MAX_SEQ_LEN - 1) + + with torch.no_grad(): + ep = export( + cache, + args=( + torch.arange(2), + torch.randn(1, N_HEADS, 2, HEAD_DIM), + torch.randn(1, N_HEADS, 2, HEAD_DIM), + ), + dynamic_shapes={ + "input_pos": {0: seq_dim}, + "k_val": {2: seq_dim}, + "v_val": {2: seq_dim}, + }, + strict=True, + ) + + mod = ep.module() + k = torch.randn(1, N_HEADS, 3, HEAD_DIM) + v = torch.randn(1, N_HEADS, 3, HEAD_DIM) + k_p, k_n, v_p, v_n = mod(torch.arange(3), k, v) + + self.assertEqual(k_p.shape, (1, N_HEADS, MAX_SEQ_LEN, 
HEAD_DIM // 2))
        self.assertEqual(k_n.shape, (1, N_HEADS, MAX_SEQ_LEN, 1))

    def test_exported_state_accumulates(self):
        # Mutable buffers must persist across calls of the exported module.
        cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS)
        seq_dim = Dim("seq", min=1, max=MAX_SEQ_LEN - 1)

        with torch.no_grad():
            ep = export(
                cache,
                args=(
                    torch.arange(2),
                    torch.randn(1, N_HEADS, 2, HEAD_DIM),
                    torch.randn(1, N_HEADS, 2, HEAD_DIM),
                ),
                dynamic_shapes={
                    "input_pos": {0: seq_dim},
                    "k_val": {2: seq_dim},
                    "v_val": {2: seq_dim},
                },
                strict=True,
            )

        mod = ep.module()

        # Write positions 0-1
        k0 = torch.randn(1, N_HEADS, 2, HEAD_DIM)
        mod(torch.arange(2), k0, torch.randn(1, N_HEADS, 2, HEAD_DIM))

        # Write positions 2-3, get full cache back
        k_p, k_n, _, _ = mod(
            torch.arange(2, 4),
            torch.randn(1, N_HEADS, 2, HEAD_DIM),
            torch.randn(1, N_HEADS, 2, HEAD_DIM),
        )

        # Positions 0-1 should be non-zero (k0's data preserved)
        self.assertGreater(k_p[:, :, :2].abs().max().item(), 0)

        # Positions 4+ should be zero (never written)
        self.assertEqual(k_p[:, :, 4:].abs().max().item(), 0)


class TestCodebook(unittest.TestCase):
    """Lloyd-Max codebook and rotation matrix correctness."""

    def test_centroids_sorted(self):
        # 4 bits -> 16 centroids, 15 decision boundaries, strictly increasing.
        centroids, boundaries = solve_lloyd_max(HEAD_DIM, BITS)
        self.assertEqual(centroids.shape, (16,))
        self.assertEqual(boundaries.shape, (15,))
        self.assertTrue(torch.all(centroids[1:] > centroids[:-1]))
        self.assertTrue(torch.all(boundaries[1:] > boundaries[:-1]))

    def test_centroids_symmetric(self):
        """Codebook should be roughly symmetric around zero."""
        centroids, _ = solve_lloyd_max(HEAD_DIM, BITS)
        self.assertAlmostEqual(centroids.mean().item(), 0.0, places=4)

    def test_boundaries_between_centroids(self):
        # Each boundary lies strictly between its neighboring centroids.
        centroids, boundaries = solve_lloyd_max(HEAD_DIM, BITS)
        for i in range(len(boundaries)):
            self.assertGreater(boundaries[i].item(), centroids[i].item())
            self.assertLess(boundaries[i].item(), centroids[i + 1].item())

    def
test_codebook_deterministic(self): + c1, b1 = solve_lloyd_max(HEAD_DIM, BITS) + c2, b2 = solve_lloyd_max(HEAD_DIM, BITS) + self.assertTrue(torch.equal(c1, c2)) + self.assertTrue(torch.equal(b1, b2)) + + def test_codebook_varies_with_dim(self): + c64, _ = solve_lloyd_max(64, BITS) + c256, _ = solve_lloyd_max(256, BITS) + self.assertFalse(torch.allclose(c64, c256)) + + def test_rotation_orthogonal(self): + R = generate_rotation_matrix(HEAD_DIM) + self.assertEqual(R.shape, (HEAD_DIM, HEAD_DIM)) + eye = R @ R.T + self.assertTrue(torch.allclose(eye, torch.eye(HEAD_DIM), atol=1e-5)) + + def test_rotation_deterministic(self): + R1 = generate_rotation_matrix(HEAD_DIM, seed=42) + R2 = generate_rotation_matrix(HEAD_DIM, seed=42) + self.assertTrue(torch.equal(R1, R2)) + + def test_rotation_varies_with_seed(self): + R1 = generate_rotation_matrix(HEAD_DIM, seed=42) + R2 = generate_rotation_matrix(HEAD_DIM, seed=99) + self.assertFalse(torch.equal(R1, R2)) + + +class TestEdgeCases(unittest.TestCase): + + def test_odd_head_dim_raises(self): + with self.assertRaises(ValueError): + TurboQuantKVCache(2, 127, 32) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/llm/modules/turboquant/__init__.py b/extension/llm/modules/turboquant/__init__.py new file mode 100644 index 00000000000..d82b441287d --- /dev/null +++ b/extension/llm/modules/turboquant/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 

from executorch.extension.llm.modules.turboquant.kv_cache import TurboQuantKVCache

__all__ = [
    "TurboQuantKVCache",
]
diff --git a/extension/llm/modules/turboquant/codebook.py b/extension/llm/modules/turboquant/codebook.py
new file mode 100644
index 00000000000..ddcb3a5cf3f
--- /dev/null
+++ b/extension/llm/modules/turboquant/codebook.py
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Lloyd-Max codebook solver and rotation matrix generator adapted from
# turboquant-vllm (Alberto-Codes/turboquant-vllm, Apache 2.0).
#
# Reference: arXiv 2504.19874 — "TurboQuant: Online Vector Quantization
# with Near-optimal Distortion Rate" (ICLR 2026).

"""
Lloyd-Max optimal scalar quantizer and random rotation matrix for TurboQuant.

After random orthogonal rotation, each coordinate of a unit-norm vector
follows a distribution concentrated near zero (Beta for exact, Gaussian
N(0, 1/d) for the d >= 64 approximation). The Lloyd-Max algorithm finds
the optimal centroids minimizing MSE for this distribution.

Results are cached so multi-layer models pay the scipy cost only once.
"""

import math
from functools import lru_cache

import torch


def solve_lloyd_max(
    d: int,
    bits: int,
    *,
    max_iter: int = 200,
    tol: float = 1e-10,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute optimal Lloyd-Max centroids and boundaries.

    Uses the Gaussian approximation N(0, 1/d) for the rotated coordinate
    distribution, which is accurate for d >= 64.

    Args:
        d: Vector dimension.
        bits: Quantization bits (produces 2^bits centroids).
        max_iter: Maximum Lloyd-Max iterations.
        tol: Convergence tolerance on centroid movement.

    Returns:
        (centroids, boundaries) as 1-D float32 tensors.
        centroids has length 2^bits, boundaries has length 2^bits - 1.
+ """ + return _solve_lloyd_max_cached(d, bits, max_iter, tol) + + +@lru_cache(maxsize=32) +def _solve_lloyd_max_cached( + d: int, bits: int, max_iter: int, tol: float +) -> tuple[torch.Tensor, torch.Tensor]: + try: + from scipy import integrate + from scipy.stats import norm + except ImportError: + raise ImportError( + "scipy is required for TurboQuant codebook computation. " + "Install it with: pip install scipy" + ) + + n_levels = 1 << bits + sigma = 1.0 / math.sqrt(d) + lo, hi = -3.0 * sigma, 3.0 * sigma + + def pdf(x): + return float(norm.pdf(x, loc=0.0, scale=sigma)) + + centroids = [lo + (hi - lo) * (i + 0.5) / n_levels for i in range(n_levels)] + + for _ in range(max_iter): + boundaries = [ + (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1) + ] + edges = [lo] + boundaries + [hi] + new_centroids = [] + for i in range(n_levels): + a, b = edges[i], edges[i + 1] + if b - a < 1e-15: + new_centroids.append((a + b) / 2.0) + continue + numer, _ = integrate.quad(lambda x: x * pdf(x), a, b) + denom, _ = integrate.quad(pdf, a, b) + if denom < 1e-15: + new_centroids.append((a + b) / 2.0) + else: + new_centroids.append(numer / denom) + + max_shift = max(abs(new_centroids[i] - centroids[i]) for i in range(n_levels)) + centroids = new_centroids + if max_shift < tol: + break + + boundaries_final = [ + (centroids[i] + centroids[i + 1]) / 2.0 for i in range(n_levels - 1) + ] + return ( + torch.tensor(centroids, dtype=torch.float32), + torch.tensor(boundaries_final, dtype=torch.float32), + ) + + +def generate_rotation_matrix(dim: int, seed: int = 42) -> torch.Tensor: + """Generate a Haar-distributed random orthogonal matrix via QR. + + Args: + dim: Matrix dimension (d x d). + seed: Random seed for reproducibility. + + Returns: + Orthogonal matrix of shape (dim, dim) in float32 on CPU. 
+ """ + gen = torch.Generator(device="cpu").manual_seed(seed) + gaussian = torch.randn(dim, dim, generator=gen, device="cpu", dtype=torch.float32) + q, r = torch.linalg.qr(gaussian) + diag_sign = torch.sign(torch.diag(r)) + diag_sign[diag_sign == 0] = 1.0 + return q * diag_sign.unsqueeze(0) diff --git a/extension/llm/modules/turboquant/kv_cache.py b/extension/llm/modules/turboquant/kv_cache.py new file mode 100644 index 00000000000..12c01721a15 --- /dev/null +++ b/extension/llm/modules/turboquant/kv_cache.py @@ -0,0 +1,166 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +TurboQuant KV cache compression for torch.export(strict=True). + +Compresses KV cache to TQ4 nibble-packed format (3.8x memory savings) +using the TurboQuant algorithm (arXiv 2504.19874, ICLR 2026). The +codebook and rotation matrix are precomputed at init time; the forward +path is pure PyTorch ops. + +Paired with the fused ``triton::tq4_sdpa`` kernel, attention runs +directly on compressed data — the full decompressed cache is never +materialized. + +Usage:: + + from executorch.extension.llm.modules.turboquant import TurboQuantKVCache + + # Replace KV cache in attention module, then set the flag: + attn.kv_cache = TurboQuantKVCache(n_heads, head_dim, max_seq_len) + attn.turboquant = True + +See ``examples/models/qwen3_5_moe/export.py`` for a full integration +example with model-specific replacement logic. +""" + +import torch +import torch.nn as nn + +from executorch.extension.llm.modules.turboquant.codebook import ( + generate_rotation_matrix, + solve_lloyd_max, +) + + +class TurboQuantKVCache(nn.Module): + """KV cache with TQ4 compression. + + Stores K/V as nibble-packed uint8 indices (2 indices per byte) plus + bf16 per-vector norms. 
The ``update()`` method compresses incoming
    K/V and returns the compressed cache buffers for use with the fused
    ``triton::tq4_sdpa`` kernel. A ``_decompress()`` method is provided
    for testing.

    Args:
        n_heads: Number of KV heads.
        head_dim: Dimension per head (must be even).
        max_seq_len: Maximum sequence length (cache is pre-allocated).
        bits: Quantization bits per coordinate (must be 4).
        seed: Random seed for the rotation matrix.

    Note:
        Batch size is fixed to 1 (standard for ExecuTorch inference).
        Input tensors must have shape ``(1, H, T, D)``.
    """

    def __init__(self, n_heads, head_dim, max_seq_len, bits=4, seed=42):
        super().__init__()
        # TQ4 is hard-wired: one uint8 holds two 4-bit codebook indices.
        if bits != 4:
            raise ValueError(
                f"Only 4-bit quantization is supported (nibble packing + "
                f"16-entry codebook). Got bits={bits}."
            )
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even, got {head_dim}")

        self.n_heads = n_heads
        self.head_dim = head_dim
        # Packed width: two indices per byte.
        self.half_dim = head_dim // 2

        # Precompute codebook + rotation once at init; both end up as
        # buffers so they travel with state_dict and survive export.
        centroids, boundaries = solve_lloyd_max(head_dim, bits)
        rotation = generate_rotation_matrix(head_dim, seed=seed)

        self.register_buffer("centroids", centroids)
        self.register_buffer("boundaries", boundaries.to(torch.bfloat16))
        self.register_buffer("rotation", rotation)
        # Transposed copy kept contiguous in bf16 for the matmul in _compress().
        self.register_buffer("rotation_T", rotation.T.to(torch.bfloat16).contiguous())

        # Compressed cache buffers
        self.register_buffer(
            "k_packed",
            torch.zeros(1, n_heads, max_seq_len, self.half_dim, dtype=torch.uint8),
        )
        self.register_buffer(
            "k_norms",
            torch.zeros(1, n_heads, max_seq_len, 1, dtype=torch.bfloat16),
        )
        self.register_buffer(
            "v_packed",
            torch.zeros(1, n_heads, max_seq_len, self.half_dim, dtype=torch.uint8),
        )
        self.register_buffer(
            "v_norms",
            torch.zeros(1, n_heads, max_seq_len, 1, dtype=torch.bfloat16),
        )

    def _compress(self, x):
        """Compress ``(1, H, T, D)`` tensor to nibble-packed uint8 + bf16 norms.

        All ops are torch.export-compatible: norm, matmul, bucketize, bitwise.
        Stays in bf16 throughout — TQ4 quantization error dominates bf16 rounding.
        """
        orig_shape = x.shape
        flat = x.reshape(-1, self.head_dim).to(self.rotation_T.dtype)

        # Separate magnitude from direction; only the rotated direction
        # gets quantized, the norm is stored alongside.
        norms = torch.linalg.vector_norm(flat, dim=-1, keepdim=True)
        normalized = flat / (norms + 1e-10)  # eps guards all-zero vectors
        rotated = normalized @ self.rotation_T
        # bucketize against the 15 boundaries yields a 0..15 index per coord.
        indices = torch.bucketize(rotated, self.boundaries)

        # Pack adjacent coordinate pairs: even position -> high nibble,
        # odd position -> low nibble.
        idx_u8 = indices.to(torch.uint8)
        packed = (idx_u8[:, 0::2] << 4) | idx_u8[:, 1::2]

        return (
            packed.reshape(*orig_shape[:-1], self.half_dim),
            norms.reshape(*orig_shape[:-1], 1).to(torch.bfloat16),
        )

    def _decompress(self, packed, norms):
        """Decompress nibble-packed uint8 + norms back to float tensor.

        Provided for testing — the fused ``tq4_sdpa`` kernel decompresses
        per-tile in the attention inner loop, never calling this method.
        """
        orig_batch_shape = packed.shape[:-1]
        flat_packed = packed.reshape(-1, self.half_dim)
        flat_norms = norms.reshape(-1, 1).float()

        # Unpack the two 4-bit indices stored in each byte.
        high = (flat_packed >> 4).long()
        low = (flat_packed & 0x0F).long()
        indices = torch.stack([high, low], dim=-1).reshape(-1, self.head_dim)

        # Codebook lookup, undo the rotation, then restore magnitudes.
        reconstructed = self.centroids.float()[indices]
        unrotated = reconstructed @ self.rotation_T.float().T
        scaled = unrotated * flat_norms

        return scaled.reshape(*orig_batch_shape, self.head_dim)

    def forward(self, input_pos, k_val, v_val):
        # nn.Module entry point — simply delegates to update().
        return self.update(input_pos, k_val, v_val)

    def update(self, input_pos, k_val, v_val):
        """Compress and store K/V, return compressed cache buffers.

        Args:
            input_pos: ``(T,)`` position indices.
            k_val: ``(1, H, T, D)`` key tensor (batch size must be 1).
            v_val: ``(1, H, T, D)`` value tensor (batch size must be 1).

        Returns:
            Tuple of ``(k_packed, k_norms, v_packed, v_norms)`` — the full
            compressed cache (all positions, not just the new tokens).
+ """ + k_packed, k_norms = self._compress(k_val) + v_packed, v_norms = self._compress(v_val) + + self.k_packed[:, :, input_pos] = k_packed + self.k_norms[:, :, input_pos] = k_norms + self.v_packed[:, :, input_pos] = v_packed + self.v_norms[:, :, input_pos] = v_norms + + return self.k_packed, self.k_norms, self.v_packed, self.v_norms