
TurboQuant TQ4 KV cache compression for Qwen 3.5 MoE#18687

Merged
mergennachin merged 1 commit into main from mergennachin/turboquant on Apr 5, 2026

Conversation

@mergennachin
Contributor

@mergennachin mergennachin commented Apr 3, 2026

TurboQuant TQ4 KV cache compression for Qwen 3.5 MoE

Add TurboQuant (arXiv 2504.19874) KV cache compression to the CUDA
backend, reducing KV cache memory 3.8x by storing nibble-packed uint8
indices + bf16 norms instead of bf16 tensors. A fused Triton SDPA
kernel decompresses K/V per-tile in the attention inner loop so the
full cache is never materialized.
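For readers unfamiliar with the layout, the nibble packing described above can be sketched in plain PyTorch (a minimal illustration, not the code in this PR; function names are mine):

```python
import torch

def pack_nibbles(indices: torch.Tensor) -> torch.Tensor:
    """Pack pairs of 4-bit codebook indices (values 0..15) into uint8.

    Even-position indices go to the high nibble, odd to the low nibble,
    i.e. the (even << 4) | odd layout. indices: [..., D] uint8 with D even
    -> returns [..., D // 2] uint8, halving the index storage.
    """
    even = indices[..., 0::2]
    odd = indices[..., 1::2]
    return (even << 4) | odd

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """Invert pack_nibbles: [..., D // 2] uint8 -> [..., D] uint8."""
    even = packed >> 4
    odd = packed & 0x0F
    # Interleave back to [even_0, odd_0, even_1, odd_1, ...]
    return torch.stack((even, odd), dim=-1).flatten(-2)

idx = torch.randint(0, 16, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_nibbles(pack_nibbles(idx)), idx)
```

The real kernel never materializes the unpacked indices as a tensor; it performs the equivalent split and centroid lookup per tile inside the attention loop.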

Components:

  • backends/cuda/triton/kernels/tq4_sdpa.py: Fused TQ4 Flash Attention
    kernel with Pack GQA optimization (adapted from sdpa.py structure),
    precomputed [256]-entry bf16 LUT for zero-bit-op centroid gather,
    and norm factoring (norms multiplied on [M,N] QK/P matrices instead
    of [N,D] K/V tiles). NaN-safe softmax guards for sparse masks.
    Registered as @triton_op for torch.export + CUDA backend lowering.

  • extension/llm/modules/turboquant/: TurboQuantKVCache nn.Module with
    bf16 compression path and self-contained Lloyd-Max codebook solver
    (no external dependencies; scipy used lazily for codebook init only).

  • examples/models/qwen3_5_moe/: --turboquant flag in export.py, branch
    in FullAttention.forward() between standard SDPA and tq4_sdpa.

  • backends/aoti/: Added aoti_torch_dtype_uint8 shim and Byte ScalarType
    to slim headers (required for uint8 KV cache tensors in C++ runtime).

Performance (A100, Qwen 3.5 MoE, B=1, GQA 16:2, D=256, seq=4096):
TQ4 SDPA kernel: 0.66ms (was 3.74ms before optimizations)
Baseline bf16: 0.45ms (1.5x overhead for 3.8x memory savings)
Full AOTI path: 0.79ms (inductor fuses compress ops)

The full Qwen 3.5 MoE model with TurboQuant KV cache compression runs at 75% of baseline decode speed (78 -> 60) with 3.8x memory savings relative to the normal KV cache. End-to-end validated through the C++ runner. Note that the full-attention KV cache is only a small part of Qwen 3.5 MoE, since 3/4 of the model uses recurrent states; at 200K context length, TurboQuant saves about 3GB of KV cache (4GB -> 1GB).
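As a sanity check on the headline number, the compression ratio follows from the bit widths alone (a back-of-envelope sketch under my own assumptions about norm storage; exact per-tensor overheads in the PR may differ):

```python
# Back-of-envelope check of the reported ~3.8x KV cache savings.
# Assumptions (mine, not stated exactly in the PR): 16-bit bf16 baseline,
# 4-bit indices, and one per-vector norm of either bf16 or fp32 width.
D = 256  # head_dim for Qwen 3.5 MoE full attention
baseline_bits = 16 * D  # one bf16 K or V vector
for norm_bits, label in ((16, "bf16 norm"), (32, "fp32 norm")):
    tq4_bits = 4 * D + norm_bits  # nibble indices + one norm scalar
    print(f"{label}: {baseline_bits / tq4_bits:.2f}x")
# bf16 norm: 3.94x, fp32 norm: 3.88x -- both consistent with the ~3.8x claim
```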

Copilot AI review requested due to automatic review settings April 3, 2026 17:04
@pytorch-bot

pytorch-bot bot commented Apr 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18687

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 3 Unrelated Failures

As of commit 94973a3 with merge base 300e368:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Apr 3, 2026
@github-actions

github-actions bot commented Apr 3, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from a292381 to 969767b Compare April 3, 2026 17:06
@mergennachin
Contributor Author

@claude Review this PR

@claude

claude bot commented Apr 3, 2026

Claude finished @mergennachin's task in 6m 45s


Review of PR #18687 — TurboQuant TQ4 KV Cache Compression

  • Read all changed files
  • Review Triton kernel (tq4_sdpa.py)
  • Review KV cache module (turboquant/kv_cache.py)
  • Review model/export changes
  • Review AOTI shims
  • Review tests
  • Review CI changes

Overall this is a well-structured PR. The kernel, cache module, model integration, tests, and CI changes are all consistent. I have a few observations and suggestions below.


Triton Kernel — backends/cuda/triton/kernels/tq4_sdpa.py

Correctness: The nibble decompression logic using tl.join + .reshape correctly inverts the packing (even << 4) | odd: tl.join produces [BLOCK_N, HALF_D_PAD, 2], which reshapes to the interleaved order [even_0, odd_0, even_1, odd_1, ...], matching the original index order. The online softmax implementation follows the standard two-pass-free algorithm.

  1. Dummy mask pointer when HAS_MASK=False (line 302): When no mask is provided, k_packed is passed as the dummy Mask_ptr. While the pointer is never dereferenced (the if HAS_MASK branch is compiled out), this is a bit surprising. The existing sdpa.py uses 0 as the dummy value. This is fine functionally but slightly unusual.

  2. IS_CAUSAL is always False (line 295): The wrapper hardcodes is_causal = False and relies on explicit masks from the caller. This is fine for the current use case where FullAttention always passes a mask, but the kernel itself supports IS_CAUSAL — worth noting in case someone tries to call tq4_sdpa without a mask for causal attention. The docstring could mention that causal behavior requires passing an explicit mask.

  3. Pre/post rotation in the wrapper (lines 284, 341): The Q @ Pi^T pre-rotation and out @ Pi post-rotation are done outside the kernel as full torch.matmul in float32. This is correct and avoids kernel complexity, but these are O(B·H·L·D²) matmuls. For long-sequence prefill, this cost may be non-trivial relative to the attention itself. Not a blocker — just something to watch in profiling.

  4. No q_rot contiguity check: The kernel passes *q_rot.stride() which could be wrong if q_rot is non-contiguous. In practice torch.matmul returns a contiguous tensor, so this is fine, but an explicit .contiguous() call like k_n = k_norms.reshape(...).contiguous() (line 280) would be defensive. Minor.

  5. Autotune search space (lines 48-55): The configs are reasonable. The filter if not (w == 8 and BM < 64) makes sense since 8 warps on small tiles wastes occupancy.


KV Cache Module — extension/llm/modules/turboquant/kv_cache.py

  1. Hard dependency on turboquant_vllm (line 34): The import is at module top-level, so importing TurboQuantKVCache anywhere will fail if turboquant-vllm is not installed. The __init__.py imports it unconditionally. The kernel's __init__.py wraps the import in try/except ImportError, but the module-level __init__.py at extension/llm/modules/turboquant/__init__.py does not. This means from executorch.extension.llm.modules.turboquant import TurboQuantKVCache will raise an ImportError instead of a clear error message. Consider either:

    • Adding a try/except in __init__.py with a helpful message, or
    • Lazy-importing turboquant_vllm in TurboQuantKVCache.__init__ (the solve/rotation calls only happen at init time, not in forward).


  2. _compress epsilon (line 98): norms + 1e-10 uses a very small epsilon for division safety. For bf16 inputs cast to float32, this is fine. The same pattern appears in the test helper.

  3. forward delegates to update (line 130-131): Clean pattern for nn.Module — allows both cache(pos, k, v) and cache.update(pos, k, v).

  4. Batch dimension hardcoded to 1 (lines 72-87): Cache buffers are allocated with batch=1. This is consistent with the existing KVCache pattern in the model. Fine for the current use case.
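A lazy-import shape along the lines of the second suggestion might look like this (a sketch only; the class skeleton and attribute names here are illustrative, not the PR's actual code — the real module subclasses nn.Module):

```python
class TurboQuantKVCacheSketch:
    """Illustrative skeleton showing the lazy-import pattern suggested above."""

    def __init__(self, n_heads: int, head_dim: int, max_seq_len: int):
        # Defer the optional dependency to init time (the solve/rotation
        # calls only happen here, not in forward), and surface a clear
        # error message if the package is missing.
        try:
            from turboquant_vllm import solve_lloyd_max  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "TurboQuantKVCache requires the optional 'turboquant-vllm' "
                "package; install it with: pip install turboquant-vllm==1.4.0"
            ) from e
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
```

With this shape, importing the module succeeds without the dependency, and constructing the cache without it fails with an actionable message instead of a bare ModuleNotFoundError.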


Model/Export Changes — examples/models/qwen3_5_moe/

  1. turboquant attribute on FullAttention (model.py line 232, line 270-283): The self.turboquant = False flag set in __init__ and checked in forward is a clean way to branch between standard SDPA and TQ4. The flag is set to True by _apply_turboquant in export.py.

  2. _apply_turboquant placement (export.py lines 351-374): Running after _materialize_buffers is correct — the TQ4 buffers (uint8, fp32) need to be created with their real dtypes, not the bf16 blanket cast.

  3. Full cache returned every call (kv_cache.py line 153): update() returns the entire self.k_packed etc. on every call, meaning the full cache buffer is part of the graph output. This is consistent with how the standard KVCache.update() works (model.py line 196-198), so it integrates cleanly with the existing pattern.


AOTI Shims — backends/aoti/common_shims_slim.*

  1. Added aoti_torch_dtype_uint8: Returns 0 (ScalarType::Byte). This is the correct value for PyTorch's uint8 dtype. The declaration and implementation are consistent. This was needed because the TQ4 packed buffers are uint8 tensors that flow through the AOTI-generated wrapper code.

Tests

  1. test_tq4_sdpa.py — Comprehensive kernel tests covering MHA, GQA (2x, 4x, 8x), various head dims (64, 128, 256), decode and prefill, bool masks, batch > 1, and an export+C++ runner e2e test. Cosine similarity threshold of 0.95 is reasonable for 4-bit quantized attention. The structure follows the existing test_triton_sdpa.py pattern.

  2. test_turboquant_kv_cache.py — Good coverage: numerical equivalence with the library, nibble packing roundtrip, bf16 tolerance, torch.export compatibility, and state accumulation across calls. Runs on CPU.

  3. test_turboquant.py — Integration tests with a tiny model: eager quality, greedy decode matching, export matching eager, cache state matters, replacement count. These are the most important tests for validating the full pipeline. Note these require CUDA and turboquant-vllm.

  4. Test module in example dir (line 24-28 of test_turboquant.py): Imports _apply_turboquant, _materialize_buffers, _quantize from export.py. These are underscore-prefixed (private) functions being imported in tests, which is normal for test code.
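The cosine-similarity acceptance metric used by the kernel tests can be expressed as follows (a hedged sketch; the real helper in test_tq4_sdpa.py may be structured differently):

```python
import torch

def cosine_check(out: torch.Tensor, ref: torch.Tensor,
                 min_cosine: float = 0.95) -> bool:
    """Minimum per-vector cosine similarity between a kernel output and a
    reference, compared along the last (head_dim) axis. A 0.95 floor is a
    reasonable bar for 4-bit quantized attention, per the review above."""
    cos = torch.nn.functional.cosine_similarity(
        out.reshape(-1, out.shape[-1]).float(),
        ref.reshape(-1, ref.shape[-1]).float(),
        dim=-1,
    )
    return bool((cos.min() >= min_cosine).item())
```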


CI Changes

  1. cuda.yml: Adds pip install "turboquant-vllm==1.4.0" and adds the new test file to the pytest command. The pinned version matches the compatibility note in the code.

  2. export_model_artifact.sh: Adds EXTRA_PIP="turboquant-vllm==1.4.0" and --turboquant flag to the export command. This means the CI export artifact will now include TQ4 compressed KV caches.


Summary

The PR is well-organized and the code is correct. The Triton kernel faithfully implements the TQ4 decompression-in-attention-loop pattern from the paper. The integration into the existing model/export pipeline is clean.

Key items to consider:

  • Item 6 (import): The unconditional turboquant_vllm import at module level will cause ImportError for users who don't have the package installed but try to import the module for any reason. A lazy import or better error message would improve the experience.
  • Item 3 (rotation cost): The pre/post rotation matmuls are O(B·H·L·D²) outside the kernel. For long prefills this could be significant — worth confirming via profiling.
  • Everything else is minor or informational.

Contributor

Copilot AI left a comment


Pull request overview

Adds TurboQuant TQ4 KV-cache compression and a fused Triton SDPA path to reduce KV-cache memory for Qwen 3.5 MoE, with export/e2e coverage and CI wiring.

Changes:

  • Introduces TurboQuantKVCache that compresses K/V into nibble-packed uint8 indices + fp32 norms and exposes an update API for attention.
  • Adds a fused Triton op triton::tq4_sdpa that performs inline K/V decompression per tile during attention.
  • Integrates a --turboquant export flag + Qwen 3.5 MoE model switch, plus CI/test updates and an AOTI uint8 dtype shim.

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.

File | Description
extension/llm/modules/turboquant/kv_cache.py | Implements KV cache compression/decompression and stateful updates for export.
extension/llm/modules/turboquant/__init__.py | Exposes the TurboQuant module API.
extension/llm/modules/test/test_turboquant_kv_cache.py | CPU unit tests for compress/decompress, packing, and torch.export.
examples/models/qwen3_5_moe/test_turboquant.py | CUDA E2E test validating model quality + export behavior with TurboQuant enabled.
examples/models/qwen3_5_moe/model.py | Adds a TurboQuant attention path that calls torch.ops.triton.tq4_sdpa.
examples/models/qwen3_5_moe/export.py | Adds --turboquant and replaces full-attention KV caches with TurboQuantKVCache.
backends/cuda/triton/kernels/tq4_sdpa.py | New fused Triton SDPA kernel/op for attention over compressed KV cache.
backends/cuda/triton/kernels/__init__.py | Registers the new Triton op for import side-effects.
backends/cuda/tests/test_tq4_sdpa.py | CUDA kernel correctness + export/runner coverage for tq4_sdpa.
backends/aoti/common_shims_slim.h | Adds missing aoti_torch_dtype_uint8() declaration.
backends/aoti/common_shims_slim.cpp | Implements aoti_torch_dtype_uint8() to fix runner symbol resolution.
.github/workflows/cuda.yml | Installs turboquant-vllm and runs the new TurboQuant tests in CUDA CI.
.ci/scripts/export_model_artifact.sh | Ensures Qwen 3.5 MoE export installs turboquant-vllm and enables --turboquant.


Comment on lines +21 to +27
from executorch.extension.llm.modules.turboquant import (
TurboQuantKVCache,
replace_kv_cache_with_turboquant,
)

# After model construction, before torch.export:
replace_kv_cache_with_turboquant(model, kv_cache_class=KVCache)

Copilot AI Apr 3, 2026


The module docstring references replace_kv_cache_with_turboquant, but that symbol isn't implemented/exported anywhere in the repo. This makes the usage example incorrect and will confuse users. Either add the helper function and export it from extension/llm/modules/turboquant/__init__.py, or update the docstring to reflect the actual integration path (e.g., _apply_turboquant in the Qwen export script).

Suggested change
from executorch.extension.llm.modules.turboquant import (
TurboQuantKVCache,
replace_kv_cache_with_turboquant,
)
# After model construction, before torch.export:
replace_kv_cache_with_turboquant(model, kv_cache_class=KVCache)
from executorch.extension.llm.modules.turboquant.kv_cache import (
TurboQuantKVCache,
)
# Instantiate and wire this module into the model wherever a KV cache
# implementation is selected before torch.export.
kv_cache = TurboQuantKVCache(
n_heads=n_heads,
head_dim=head_dim,
max_seq_len=max_seq_len,
)

"""
B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape


Copilot AI Apr 3, 2026


tq4_sdpa assumes GQA-style head mapping with H_Q // H_KV and will break (division by zero or incorrect mapping) when H_Q < H_KV or H_Q % H_KV != 0. Please validate these constraints in the Python wrapper (and ideally assert k_packed.size(-1) * 2 == D as well) to fail fast with a clear error.

Suggested change
if H_KV <= 0:
raise ValueError(
f"tq4_sdpa requires k_packed to have a positive KV head dimension, got H_KV={H_KV}."
)
if H_Q < H_KV:
raise ValueError(
f"tq4_sdpa requires GQA-compatible head mapping with H_Q >= H_KV, got H_Q={H_Q}, H_KV={H_KV}."
)
if H_Q % H_KV != 0:
raise ValueError(
f"tq4_sdpa requires GQA-compatible head mapping with H_Q % H_KV == 0, got H_Q={H_Q}, H_KV={H_KV}."
)
if k_packed.size(-1) * 2 != D:
raise ValueError(
f"tq4_sdpa requires k_packed.size(-1) * 2 == query.size(-1), got "
f"k_packed.size(-1)={k_packed.size(-1)} and D={D}."
)
if v_packed.size(-1) != HALF_D:
raise ValueError(
f"tq4_sdpa requires v_packed.size(-1) == k_packed.size(-1), got "
f"v_packed.size(-1)={v_packed.size(-1)} and k_packed.size(-1)={HALF_D}."
)

Comment on lines +291 to +305
# Masking: use explicit mask when provided, never auto-causal.
# The caller (FullAttention) always provides an explicit bool mask
# that handles both prefill (lower-triangular) and decode (row mask).
HAS_MASK = attn_mask is not None
is_causal = False
if HAS_MASK:
Mask_ptr = attn_mask
stride_mb = attn_mask.stride(0)
stride_mq = attn_mask.stride(2)
stride_mk = attn_mask.stride(3)
else:
Mask_ptr = k_packed # dummy, won't be accessed
stride_mb = 0
stride_mq = 0
stride_mk = 0

Copilot AI Apr 3, 2026


The Triton kernel ignores the mask's head dimension (it only offsets by batch), but tq4_sdpa accepts any attn_mask. If a caller passes a per-head mask (shape [B, H_Q, L_Q, L_KV]), the kernel will effectively read only the first head's mask. Consider validating attn_mask.size(1) == 1 (or extending the kernel to handle per-head masks) to avoid silent misbehavior.

Comment on lines +136 to +147
@classmethod
def setUpClass(cls):
_skip_if_no_cuda()
cls.tq4_sdpa = _import_tq4_sdpa()

def _run_test(self, B, H_q, H_kv, Lq, Lk, D, attn_mask=None, min_cosine=0.95):
torch.manual_seed(42)
centroids, boundaries, rotation = _make_codebook_and_rotation(D)
centroids = centroids.cuda()
boundaries = boundaries.cuda()
rotation = rotation.cuda()


Copilot AI Apr 3, 2026


These CUDA kernel tests rely on turboquant-vllm for reference codebook/rotation generation. If the dependency isn't installed, the tests will error rather than skip. Please add an ImportError guard (e.g., in setUpClass) that skips the suite when turboquant_vllm is unavailable, consistent with other optional-dependency tests in the repo.

@mergennachin mergennachin temporarily deployed to upload-benchmark-results April 3, 2026 17:57 — with GitHub Actions Inactive
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from 969767b to 7fb707a Compare April 3, 2026 18:36
@mergennachin
Contributor Author

@claude Look at the revision based on previous feedback. Also review the PR more holistically

@claude

claude bot commented Apr 3, 2026

Claude finished @mergennachin's task in 8m 19s


Review of PR #18687 — TurboQuant TQ4 KV Cache Compression (revision check + holistic review)

  • Check what changed since the initial review
  • Read all changed files in current state
  • Assess which previous review items were addressed
  • Holistic review of the full PR
  • Post final review

Status of Previous Review Feedback

There is only one commit (7fb707a) on this branch — no revisions have been pushed since the initial Claude and Copilot reviews. Looking at the code, it appears several items from both reviews were already present in the original commit (i.e. the initial code already anticipated some concerns). Here's the status:

Item | Status | Notes
bits != 4 validation (Copilot) | Already handled | kv_cache.py:62-66 raises ValueError if bits != 4
GQA validation in kernel (Copilot, Claude #2) | Already handled | tq4_sdpa.py:282-290 validates H_Q % H_KV != 0 and HALF_D * 2 != D
Import guard for tests (Copilot) | Already handled | test_turboquant_kv_cache.py:23-28 has try/except ImportError with HAS_TURBOQUANT flag; test_tq4_sdpa.py:140-145 has setUpClass skip; test_turboquant.py:107-112 has setUp skip
Unconditional turboquant_vllm import (Claude #6, Copilot) | Not addressed | extension/llm/modules/turboquant/__init__.py:7 still imports unconditionally
Module docstring references non-existent function (Copilot) | Already fixed | Docstring at kv_cache.py:19-26 now shows correct usage pattern
Batch dim hardcoded to 1 (Copilot) | Documented | Docstring at kv_cache.py:56-57 explicitly states "Batch size is fixed to 1"
Mask head-dim validation (Copilot) | Not addressed | Kernel still only offsets by batch, no validation that attn_mask.size(1) == 1
Dummy mask pointer (Claude #1) | Addressed | Now uses Mask_ptr = 0 with a documenting comment (tq4_sdpa.py:321), matching sdpa.py's pattern

Holistic Review

Taking a step back from individual line items, this is my assessment of the PR as a complete contribution.

Architecture & Design

The PR adds TQ4 KV cache compression as a composable, opt-in feature. The design has three well-separated layers:

  1. TurboQuantKVCache (pure PyTorch nn.Module) — handles compression/decompression, state management, and export compatibility. No CUDA dependency.
  2. tq4_sdpa (Triton kernel wrapped as @triton_op) — fused attention over compressed data. Follows the existing sdpa.py registration pattern.
  3. Model integration (flag-based branch in FullAttention.forward) — clean opt-in via self.turboquant boolean.

This layering is good. Each component is testable in isolation, and the integration surface in the model is small (one if/else branch in FullAttention.forward, model.py:270-288).

The _apply_turboquant function in export.py:351-374 runs after _materialize_buffers, which is the correct ordering — TQ4 buffers need their real dtypes (uint8, fp32), not the blanket bf16 cast.

Kernel Correctness

The Triton kernel (tq4_sdpa.py:65-237) implements online softmax with inline K/V decompression. The decompression logic is:

  • Load nibble-packed uint8 → split into high/low 4-bit indices
  • tl.join + .reshape to interleave → centroid lookup → scale by norms

This correctly inverts the packing (even << 4) | odd. The tl.join produces [BLOCK_N, HALF_D_PAD, 2] which reshapes to [BLOCK_N, HEAD_DIM_PAD] with interleaved [even_0, odd_0, even_1, odd_1, ...] — matching the original dimension order.
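For reference, the same decompression steps can be mirrored in plain PyTorch (an illustrative stand-in for the Triton code, useful for unit-testing the packing convention; function and tensor names are mine):

```python
import torch

def decompress_tile(packed: torch.Tensor, centroids: torch.Tensor,
                    norms: torch.Tensor) -> torch.Tensor:
    """PyTorch reference for the per-tile decompression described above.

    packed:    [N, D // 2] uint8 nibble-packed centroid indices
    centroids: [16] float scalar codebook
    norms:     [N] float per-vector scales factored out at compression time
    """
    hi = (packed >> 4).long()    # even-position indices (high nibble)
    lo = (packed & 0x0F).long()  # odd-position indices (low nibble)
    # stack + reshape interleaves to [even_0, odd_0, even_1, odd_1, ...],
    # the PyTorch analogue of tl.join + .reshape in the kernel
    idx = torch.stack((hi, lo), dim=-1).reshape(packed.shape[0], -1)
    return centroids[idx] * norms[:, None]
```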

The online softmax follows the standard two-pass-free algorithm with exp2 and log2(e) scaling (qk_scale = sm_scale * 1.44269504 at line 148).

One subtlety: the acc = tl.dot(p.to(v.dtype), v, acc).to(tl.float32) at line 225 uses tl.dot's third argument for in-place accumulation. The .to(tl.float32) cast after tl.dot is needed for Triton 3.6 compatibility where tl.dot may return in the input dtype rather than fp32. This is documented in the file header.
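The online-softmax recurrence referenced above can be written out in plain PyTorch (a reference sketch of the algorithm, without the exp2/log2(e) scaling trick or the tl.dot accumulator; not the kernel's code):

```python
import torch

def online_softmax_attention(q: torch.Tensor, k: torch.Tensor,
                             v: torch.Tensor, tile: int = 4) -> torch.Tensor:
    """Streaming softmax over KV tiles: q [M, D], k/v [N, D] -> [M, D].

    Maintains a running row max m, normalizer l, and accumulator acc so
    the full [M, N] score matrix is never materialized at once.
    """
    scale = q.shape[-1] ** -0.5
    m = torch.full((q.shape[0],), float("-inf"))
    l = torch.zeros(q.shape[0])
    acc = torch.zeros(q.shape[0], v.shape[-1])
    for s in range(0, k.shape[0], tile):
        qk = (q @ k[s:s + tile].T) * scale             # [M, tile] scores
        m_new = torch.maximum(m, qk.max(dim=-1).values)
        alpha = torch.exp(m - m_new)                   # rescale old state
        p = torch.exp(qk - m_new[:, None])
        l = l * alpha + p.sum(dim=-1)
        acc = acc * alpha[:, None] + p @ v[s:s + tile]
        m = m_new
    return acc / l[:, None]
```

In the TQ4 kernel, each k/v tile in this loop is additionally decompressed from its packed form before the dot products.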

Test Coverage

Coverage is comprehensive and well-structured:

  • Kernel tests (test_tq4_sdpa.py): MHA, GQA (2x/4x/8x), various head dims (64/128/256), decode and prefill sequence lengths, bool masks, causal masks, batch > 1, Qwen 3.5 MoE-specific config (16:2 GQA, head_dim=256), export + C++ runner e2e.
  • KV cache module tests (test_turboquant_kv_cache.py): numerical equivalence with turboquant-vllm, nibble packing roundtrip exhaustive (all 256 pairs), bf16 tolerance, torch.export(strict=True) compatibility, state accumulation, edge cases.
  • Integration tests (test_turboquant.py): eager quality, decode quality, export-matches-eager, cache state matters, replacement count.

All test files have proper skip guards for the optional turboquant-vllm dependency.

Remaining Issues

1. Unconditional turboquant_vllm import in __init__.py — This is the one item from both previous reviews that's not addressed. extension/llm/modules/turboquant/__init__.py:7 does a bare from ... import TurboQuantKVCache which triggers from turboquant_vllm import solve_lloyd_max at module level. If anyone imports executorch.extension.llm.modules.turboquant without having turboquant-vllm installed, they get an unhelpful ImportError: No module named 'turboquant_vllm'. A lazy import or try/except with a clear error message would be better.

2. Mask head-dimension not validated — The kernel accesses the mask with mask_b_base + offs_m[:, None] * stride_mq + (start_n + offs_n[None, :]) * stride_mk where mask_b_base = Mask + off_z * stride_mb. This skips the head dimension entirely (no off_h_q * stride_mh term). If a caller passes a per-head mask [B, H_Q, L_Q, L_KV] instead of [B, 1, L_Q, L_KV], the kernel silently reads only the first head's mask for all heads. The existing sdpa.py validates attn_mask.shape[1] == 1 in _prepare_mask_params. Adding the same check in tq4_sdpa would prevent silent misbehavior.

3. Input dtype/device validation — The existing sdpa.py validates that all tensors are CUDA and bf16 via _validate_sdpa_inputs(). The tq4_sdpa wrapper does not validate dtype or device. Since the kernel will silently produce wrong results on CPU tensors or float32 inputs, consider adding basic guards. Not a blocker — callers in this PR always pass correct types.

4. Pre/post rotation cost — The Q @ Pi^T (line 302) and out @ Pi (line 360) rotations outside the kernel are O(B·H·L·D²) matmuls. For the Qwen 3.5 MoE config (D=256), these are 256x256 matmuls per head per token. During prefill with long sequences, this could be a meaningful fraction of the total compute. The code has a comment acknowledging this ("a future optimization could fuse rotation into the kernel"). Worth profiling, but not a blocker.

5. is_causal hardcoded False in wrapper call site — In model.py:274, tq4_sdpa is called without is_causal (defaults to False), relying on the explicit attn_mask from the causal mask buffer. This works but means the IS_CAUSAL kernel path is never exercised in production. The tests do cover is_causal=True separately. Fine as-is.
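For item 2 above, a guard in the spirit of sdpa.py's attn_mask.shape[1] == 1 check might look like this (a sketch; the helper name and exact shape policy are mine, not the PR's code):

```python
import torch

def validate_tq4_mask(attn_mask, B: int, L_q: int, L_kv: int) -> None:
    """Reject mask shapes the kernel cannot honor, instead of silently
    reading only the first head's mask (the kernel offsets by batch only)."""
    if attn_mask is None:
        return
    if attn_mask.dtype != torch.bool:
        raise ValueError(f"tq4_sdpa expects a bool mask, got {attn_mask.dtype}")
    if attn_mask.shape != (B, 1, L_q, L_kv):
        raise ValueError(
            "tq4_sdpa expects a broadcast mask of shape [B, 1, L_q, L_kv]; "
            f"got {tuple(attn_mask.shape)}. Per-head masks are not supported."
        )
```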

What Looks Good

  • The kernel registration follows the established @triton_op pattern from sdpa.py, making it a natural extension of the CUDA backend.
  • TurboQuantKVCache is a drop-in replacement for KVCache with the same update() interface signature (returns cache state tuples), just with different return types (packed + norms vs raw tensors).
  • The docstring on kv_cache.py correctly references the actual integration path in examples/models/qwen3_5_moe/export.py rather than a nonexistent helper function.
  • AOTI shim additions (uint8 dtype) are minimal and correct — ScalarType::Byte = 0.
  • CI changes are scoped: install turboquant-vllm==1.4.0, run the new test files. No changes to unrelated test jobs.
  • The export artifact script enables --turboquant by default for Qwen 3.5 MoE exports, which makes sense as the primary deployment target.

Summary

This is a well-structured PR. The kernel is correct, test coverage is thorough, and the integration is clean. The main actionable item is the unconditional turboquant_vllm import in __init__.py (#1 above), which will cause confusing errors for users who don't have the optional dependency. The mask validation gap (#2) is a secondary concern. Everything else is informational.



Copilot AI review requested due to automatic review settings April 3, 2026 19:32
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from 7fb707a to 2338313 Compare April 3, 2026 19:32
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 14 out of 14 changed files in this pull request and generated 5 comments.



Comment on lines 393 to 401
# turboquant-vllm requires Python >=3.12; upgrade if needed
PYTHON_MINOR=$(python -c "import sys; print(sys.version_info.minor)")
if [ "$PYTHON_MINOR" -lt 12 ]; then
echo "Upgrading Python to 3.12 (turboquant-vllm requires >=3.12)"
conda install -y python=3.12
fi
pip install safetensors huggingface_hub
pip install turboquant-vllm==1.4.0
pip install -r examples/models/qwen3_5_moe/requirements.txt

Copilot AI Apr 3, 2026


Upgrading Python in-place (conda install -y python=3.12) inside the artifact export script can break the already-installed executorch package and any compiled extensions in the environment (new interpreter / site-packages), which would make the subsequent python -m executorch... invocations fail. Consider ensuring the job/environment starts on Python 3.12 (or creating/activating a dedicated conda env for Qwen3.5 MoE export) rather than mutating the interpreter mid-script.

@mergennachin mergennachin temporarily deployed to upload-benchmark-results April 3, 2026 20:33 — with GitHub Actions Inactive
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from 2338313 to aa9190d Compare April 3, 2026 20:45
Copilot AI review requested due to automatic review settings April 3, 2026 20:46
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from aa9190d to c17c366 Compare April 3, 2026 20:46
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 14 out of 14 changed files in this pull request and generated 4 comments.



Comment on lines +279 to +321
B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

if H_Q % H_KV != 0:
raise RuntimeError(
f"H_Q must be a multiple of H_KV for GQA head mapping, "
f"got H_Q={H_Q}, H_KV={H_KV}"
)
if HALF_D * 2 != D:
raise RuntimeError(
f"k_packed last dim ({HALF_D}) * 2 must equal query head_dim ({D})"
)

sm_scale = 1.0 / math.sqrt(D)

# Reshape norms: [B, H, S, 1] -> [B, H, S]
k_n = k_norms.reshape(B, H_KV, N_KV).contiguous()
v_n = v_norms.reshape(B, H_KV, N_KV).contiguous()

# Pre-rotate Q: Q_rot = Q @ Pi^T
# O(B·H·L·D²) matmul in float32. For long-sequence prefill this cost
# may be non-trivial; a future optimization could fuse rotation into
# the kernel.
q_rot = torch.matmul(query.float(), rotation.T).to(query.dtype).contiguous()

out_rot = torch.empty_like(query)

HEAD_DIM_PAD = _next_pow2(D)
HALF_D_PAD = _next_pow2(HALF_D)

HAS_MASK = attn_mask is not None
if is_causal and N_Q != N_KV:
raise RuntimeError(
f"is_causal requires L_Q == L_KV, got L_Q={N_Q}, L_KV={N_KV}. "
"For decode (L_Q < L_KV), use an explicit bool mask instead."
)
if HAS_MASK:
Mask_ptr = attn_mask
stride_mb = attn_mask.stride(0)
stride_mq = attn_mask.stride(2)
stride_mk = attn_mask.stride(3)
else:
Mask_ptr = 0 # dummy, never dereferenced when HAS_MASK=False

Copilot AI Apr 3, 2026


tq4_sdpa does not validate key/value dtypes/shapes beyond head mapping. In particular, the kernel assumes k_packed/v_packed are uint8, k_norms/v_norms are float32 shaped [B,H,L,1] (or reshaped to [B,H,L]), centroids is length-16, rotation is [D,D], and attn_mask is bool [B,1,L_Q,L_KV]. Missing these checks can lead to hard-to-debug Triton crashes or silent misbehavior. Add explicit input validation and helpful error messages up front.

Comment on lines +151 to +155
# Run TurboQuant KV cache module tests
python -m pytest extension/llm/modules/test/test_turboquant_kv_cache.py -v -o "addopts="

# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This workflow runs TurboQuant-related tests but does not install turboquant-vllm (commented out due to Python 3.10). As a result, these tests will be skipped, so the new feature/kernel isn’t actually exercised in CI. Consider adding a separate job using Python >=3.12 (or a dedicated container) to install turboquant-vllm and run these tests, or gate these test invocations on the dependency being present to avoid giving a false sense of coverage.

Suggested change
# Run TurboQuant KV cache module tests
python -m pytest extension/llm/modules/test/test_turboquant_kv_cache.py -v -o "addopts="
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
# Run Qwen 3.5 MoE quantize roundtrip tests
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py -v -o "addopts="
# Run TurboQuant-dependent tests only when the dependency is available
if python -c "import importlib.util, sys; sys.exit(0 if importlib.util.find_spec('turboquant_vllm') else 1)"; then
python -m pytest extension/llm/modules/test/test_turboquant_kv_cache.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
else
echo "Skipping TurboQuant-specific tests because turboquant-vllm is not installed in this job."
fi

Comment on lines +394 to +395
# turboquant-vllm requires python 3.12 or above, our CI is still on 3.10
# pip install turboquant-vllm==1.4.0
Copilot AI Apr 3, 2026

turboquant-vllm is not installed here (commented out), which means the Qwen3.5 MoE export artifact path cannot exercise the TurboQuant codepath during CI/artifact generation. If TurboQuant support is expected to be validated in this flow, add a Python >=3.12 environment (or conditional installation) so the dependency can be installed and the export can cover the turboquant flag.

Suggested change
# turboquant-vllm requires python 3.12 or above, our CI is still on 3.10
# pip install turboquant-vllm==1.4.0
if python -c 'import sys; raise SystemExit(0 if sys.version_info >= (3, 12) else 1)'; then
pip install turboquant-vllm==1.4.0
else
echo "Skipping turboquant-vllm installation: Python 3.12+ is required."
fi

@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from c17c366 to f9ade5f on April 3, 2026 20:58
Copilot AI review requested due to automatic review settings April 3, 2026 21:13
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from f9ade5f to 900d589 on April 3, 2026 21:13
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 4 comments.



Comment on lines +279 to +291
B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

if H_Q % H_KV != 0:
raise RuntimeError(
f"H_Q must be a multiple of H_KV for GQA head mapping, "
f"got H_Q={H_Q}, H_KV={H_KV}"
)
if HALF_D * 2 != D:
raise RuntimeError(
f"k_packed last dim ({HALF_D}) * 2 must equal query head_dim ({D})"
)

Copilot AI Apr 3, 2026

tq4_sdpa currently relies on implicit assumptions about tensor dtypes/shapes (e.g., k_packed/v_packed uint8, k_norms/v_norms fp32 with last dim 1, centroids length 16, attn_mask bool with shape [B,1,Lq,Lkv], all CUDA and contiguous). Other Triton ops in this repo validate inputs up front (e.g., sdpa() calls _validate_sdpa_inputs in backends/cuda/triton/kernels/sdpa.py:936+). Adding similar validation here would prevent silent wrong results or hard-to-diagnose Triton failures when called with unsupported inputs.

@mergennachin mergennachin temporarily deployed to upload-benchmark-results April 3, 2026 22:18 — with GitHub Actions Inactive
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from 900d589 to d949bf9 on April 4, 2026 15:10
@mergennachin
Contributor Author

@claude Review this PR

@claude

claude bot commented Apr 4, 2026

Claude finished @mergennachin's task in 2m 45s


Review of PR #18687 — TurboQuant TQ4 KV Cache Compression

  • Read all changed files
  • Review Triton kernel (tq4_sdpa.py)
  • Review KV cache module and codebook (turboquant/)
  • Review model integration (qwen3_5_moe/)
  • Review AOTI shims
  • Review tests
  • Review CI/export changes
  • Post review

Well-structured PR. The kernel, cache module, codebook solver, model integration, and tests are all consistent and correct. The previous two reviews covered the landscape well; this review focuses on confirming the current state and flagging remaining items.


Triton Kernel — backends/cuda/triton/kernels/tq4_sdpa.py

Correctness: The nibble decompression via tl.join + .reshape correctly inverts the (even << 4) | odd packing. The online softmax uses the standard exp2-based algorithm with NaN-safe guards for fully-masked tiles (tq4_sdpa.py:204-213). The norm-factoring optimization (multiply norms on the [M,N] QK/P matrices instead of the [N,D] K/V tiles) is algebraically correct and reduces FLOPs.
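The packing scheme and the LUT trick described here can be sketched in plain Python; the helper names below are illustrative, not the kernel's actual symbols:

```python
# Two 4-bit centroid indices per byte: (even << 4) | odd, inverted with a
# shift and a mask. Names are illustrative, not taken from tq4_sdpa.py.

def pack_nibbles(indices):
    """Pack an even-length list of 4-bit indices into bytes."""
    assert len(indices) % 2 == 0 and all(0 <= i < 16 for i in indices)
    return [(indices[k] << 4) | indices[k + 1] for k in range(0, len(indices), 2)]

def unpack_nibbles(packed):
    """Invert pack_nibbles: each byte yields (high nibble, low nibble)."""
    out = []
    for b in packed:
        out.append(b >> 4)    # even-position index (high nibble)
        out.append(b & 0x0F)  # odd-position index (low nibble)
    return out

# The kernel sidesteps per-element bit ops entirely by precomputing two
# 256-entry lookup tables indexed directly by the raw packed byte:
centroids = list(range(16))  # stand-in for the 16 learned bf16 centroids
lut_hi = [centroids[b >> 4] for b in range(256)]
lut_lo = [centroids[b & 0x0F] for b in range(256)]
```

With the LUTs, decompression in the inner loop becomes two gathers per byte instead of shift-and-mask arithmetic, which is the "zero-bit-op centroid gather" the commit message refers to.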

Pack GQA: The _should_pack_gqa heuristic (tq4_sdpa.py:44-61) and dual-kernel approach (M64 vs M32 selected at tq4_sdpa.py:522-528) follow the same pattern as sdpa.py. The threshold 4 * 84 corresponds to ~4 waves on an A100 (108 SMs), which is reasonable.
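The wave-occupancy reasoning behind that threshold can be sketched as follows; the real `_should_pack_gqa` may differ in names and exact constants:

```python
import math

def should_pack_gqa(batch, n_heads_kv, seq_len, block_m):
    """Illustrative version of the heuristic described above: pack the query
    heads of a GQA group into one tile when the plain launch grid would not
    produce enough CTAs to keep the GPU busy for a few waves."""
    num_ctas = batch * n_heads_kv * math.ceil(seq_len / block_m)
    return num_ctas < 4 * 84  # ~4 waves' worth of CTAs on an A100 (108 SMs)
```

For the decode shape in this PR (B=1, 2 KV heads), the grid is tiny, so packing kicks in; large prefill grids fall back to the unpacked kernel.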

Input validation: Validation covers H_Q divisibility (tq4_sdpa.py:606-609), packed dim consistency (tq4_sdpa.py:611-613), power-of-2 head dim (tq4_sdpa.py:615-619), bool mask dtype (tq4_sdpa.py:640-644), and causal L_Q==L_KV (tq4_sdpa.py:645-649). Two gaps remain compared to sdpa.py:

  1. Missing attn_mask.shape[1] == 1 validation: The kernel only indexes the mask by [batch, seq_q, seq_kv], skipping the head dimension (tq4_sdpa.py:189-192). If a caller passes a per-head mask [B, H_Q, L_Q, L_KV], the kernel silently reads only the first head's mask for all heads. The existing sdpa.py validates this explicitly at sdpa.py:708:

    if attn_mask.shape[1] != 1:
        raise RuntimeError("attn_mask head dimension must be 1 ...")

    Adding the same check in tq4_sdpa would prevent silent misbehavior. Fix this →

  2. No dtype/device validation on Q/K/V/norms: sdpa.py validates all tensors are CUDA bf16 via _validate_sdpa_inputs (sdpa.py:666-691). tq4_sdpa does not. Passing CPU tensors or float32 Q will produce wrong results silently. The current callers always pass correct types, so this is low risk but inconsistent with the sibling kernel. Fix this →

Kernel stride passing: The launcher passes *v_norms.stride() but norms have shape [B, H, S] (3-D after reshape at tq4_sdpa.py:632), which provides (stride_vnb, stride_vnh, stride_vnn) — 3 values. The kernel body accepts 3 norm stride parameters per norm tensor (stride_knb, stride_knh, stride_knn at lines 92-94). This is correct, but the order of stride parameters at tq4_sdpa.py:544-548 skips out_rot strides before mask strides. Let me verify: *q_rot.stride() (4), *k_packed.stride() (4), *k_norms.stride() (3), *v_packed.stride() (4), *v_norms.stride() (3), *out_rot.stride() (4) — total 22 stride args. The kernel body declares: stride_qb..stride_qd (4), stride_kpb..stride_kpd (4), stride_knb..stride_knn (3), stride_vpb..stride_vpd (4), stride_vnb..stride_vnn (3), stride_ob..stride_od (4) — total 22. Matches.


KV Cache Module — extension/llm/modules/turboquant/kv_cache.py

  1. Unconditional turboquant_vllm import: This was flagged in both previous reviews and remains unaddressed. extension/llm/modules/turboquant/__init__.py:7 imports TurboQuantKVCache unconditionally, which chains to from executorch.extension.llm.modules.turboquant.codebook import .... However, looking at the current code more carefully, codebook.py does NOT import turboquant_vllm — it reimplements the Lloyd-Max solver and rotation generator from scratch. The scipy import is lazy (inside _solve_lloyd_max_cached, line 59-66). So the actual import chain is:

    __init__.pykv_cache.pycodebook.py (only needs scipy at solve time)

    This means the previous review feedback about turboquant_vllm being a hard dependency in __init__.py is wrong — the codebook module is self-contained. The turboquant_vllm dependency is only needed by the test files, not the production code path. The __init__.py import is safe for users without turboquant-vllm.

    The only actual hard dependency is scipy, which is imported lazily with a clear error message (codebook.py:63-66).

  2. Cache buffers registered as persistent: The compressed cache buffers (k_packed, k_norms, v_packed, v_norms) are registered with default persistent=True (kv_cache.py:84-99). This means they'll be serialized into state_dict/export artifacts. The existing KVCache in model.py:188-193 also uses persistent=True (default), so this is consistent with the existing pattern. However, Copilot's suggestion to use persistent=False for runtime state is worth considering for future optimization — the cache buffers are max_seq_len-sized zero tensors that don't need to be in the checkpoint. Not a blocker.

  3. Norms stored as bf16: kv_cache.py:90,98 registers the norm buffers as torch.bfloat16, and _compress returns norms cast to bf16 at line 120. The Triton kernel loads norms with tl.load, which follows the tensor's dtype, so norm precision is bf16 (not fp32 as the docstring at tq4_sdpa.py:593-594 states). The docstring's "float or bf16" is technically correct since the kernel can handle either, but in practice TurboQuantKVCache always produces bf16 norms. This is fine; TQ4 quantization error dominates.


Codebook — extension/llm/modules/turboquant/codebook.py

  1. Self-contained implementation: The codebook solver and rotation matrix generator are reimplemented from scratch, using only scipy.integrate.quad and scipy.stats.norm for the PDF integration. The scipy import is lazy and wrapped in a clear ImportError message (codebook.py:59-66). The lru_cache at line 55 ensures the scipy cost is paid only once across multiple layers. This is a clean design choice — no external turboquant-vllm dependency for the core module.

  2. QR sign normalization: generate_rotation_matrix (codebook.py:121-124) normalizes column signs using diag(R) to ensure deterministic QR factorizations across LAPACK implementations. This matches the standard Haar measure technique.

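The QR sign-normalization technique in item 2 can be sketched with NumPy; the function name and seeding below are assumptions for illustration, not codebook.py's API:

```python
import numpy as np

def generate_haar_rotation(dim, seed=0):
    """Haar-distributed orthogonal matrix via QR with sign normalization.

    Multiplying each column of Q by sign(diag(R)) pins down the otherwise
    arbitrary column signs, making the factorization deterministic across
    LAPACK implementations while preserving the Haar distribution.
    """
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((dim, dim))
    q, r = np.linalg.qr(a)
    q *= np.sign(np.diag(r))  # broadcasts over columns
    return q
```

Without the sign fix, two BLAS backends can return Q matrices differing by column signs, which would make exported codebooks non-reproducible.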

Model Integration — examples/models/qwen3_5_moe/model.py

  1. TQ4 branch in FullAttention.forward (model.py:270-283): Clean opt-in via self.turboquant boolean. The kernel receives self.kv_cache.centroids and self.kv_cache.rotation — these are registered buffers on the cache module, which means they'll be traced correctly by torch.export.

  2. No batch size validation: The TQ4 path doesn't check B == 1 (model.py:270). TurboQuantKVCache allocates with batch=1. If somehow called with B>1, the cache assignment self.k_packed[:, :, input_pos] = k_packed at kv_cache.py:161 would fail with a shape mismatch. The error would be clear enough from the shape mismatch, but an explicit guard would be slightly more helpful. Minor.


AOTI Shims — backends/aoti/

  1. aoti_torch_dtype_uint8: Returns 0 (ScalarType::Byte) — correct per PyTorch's c10::ScalarType. The ScalarType.h header already had Byte = 0 in the enum (ScalarType.h:26), so this just adds the missing C API accessor. The declaration in common_shims_slim.h:79 and implementation in common_shims_slim.cpp:137-139 are consistent. This was the only missing dtype shim for the types used in the TQ4 compressed cache.

Tests

  1. Kernel tests (test_tq4_sdpa.py): Thorough coverage — MHA, GQA (2x/4x/8x/MQA), head dims 64/128/256, decode and prefill, bool masks, causal masks, sparse masks, batch>1, all-masked NaN safety, validation error paths, Qwen 3.5 MoE config, export + C++ runner e2e. Cosine threshold 0.95 is appropriate for 4-bit quantized attention. The reference path (lines 89-115) compresses then decompresses in float32 before running standard SDPA, giving a clean comparison.

  2. KV cache tests (test_turboquant_kv_cache.py): CPU-only, covers roundtrip quality (>0.99 cosine), output shapes, bf16 input, state accumulation, head_dim=256, nibble packing exhaustive roundtrip (all 256 pairs), torch.export(strict=True) compatibility with dynamic shapes, codebook properties (sorted, symmetric, deterministic, varies with dim), rotation orthogonality. Good.

  3. Integration tests (test_turboquant.py): Tiny model (no downloads), eager quality, decode quality, export-matches-eager, cache state matters, replacement count. Requires CUDA. Test skip is in setUp (test_turboquant.py:107-108).

  4. Test skip guards: All three test files handle missing dependencies:

    • test_tq4_sdpa.py:146-148: _skip_if_no_cuda() in setUpClass
    • test_turboquant_kv_cache.py: No optional deps needed (codebook is self-contained)
    • test_turboquant.py:107-108: skipTest if no CUDA

    Note: test_turboquant_kv_cache.py imports TurboQuantKVCache at module level (line 22), which will chain to codebook.py. Since codebook.py only needs scipy at solve time (lazy import), the import itself won't fail even without scipy. However, any test that instantiates TurboQuantKVCache will call solve_lloyd_max in __init__, which does need scipy. If scipy is missing, the test will fail at instantiation time with a clear ImportError message rather than a skipTest. This is acceptable for CI (scipy is always available) but could be improved for local development.


CI Changes

  1. cuda.yml:148-149: The unittest-cuda job runs both test_turboquant.py and test_quantize_roundtrip.py in the same pytest invocation. The Triton kernel tests (test_tq4_sdpa.py) are in backends/cuda/tests/ which is already covered by line 146's glob backends/cuda/tests. So TQ4 kernel tests run in CI.

  2. export_model_artifact.sh:391-430: The Qwen 3.5 MoE export path (MODEL_NAME="qwen3_5_moe") downloads a prequantized checkpoint, runs inference sanity check, and exports with --turboquant. No turboquant-vllm is needed here because the production code path only needs scipy (via codebook.py), not turboquant-vllm. The EXTRA_PIP="" is correct.

  3. Missing turboquant-vllm in CI for tests: The unittest-cuda job at cuda.yml:121-149 does NOT install turboquant-vllm. The test_turboquant.py integration test imports _apply_turboquant which imports TurboQuantKVCache which only needs scipy/codebook.py. So this works. The kernel test_tq4_sdpa.py imports TurboQuantKVCache and solve_lloyd_max/generate_rotation_matrix from codebook.py — also works without turboquant-vllm. All production tests run without turboquant-vllm. The only tests that needed turboquant-vllm were the library compatibility tests in the KV cache test file, and those appear to have been removed (I see no turboquant_vllm import in any test file).


Kernel Init Import — backends/cuda/triton/kernels/__init__.py

  1. Graceful fallback: The TQ4 kernel import is wrapped in try/except ImportError (__init__.py:26-31), matching the existing pattern for chunk_gated_delta_rule. This means importing the kernels package never fails even if triton isn't available. Clean.

Items From Previous Reviews — Status

Item Status
Unconditional turboquant_vllm import Not an issue — codebook.py is self-contained; turboquant_vllm is not imported anywhere in production code
Mask head-dim validation Still missing — see item 1 above
dtype/device validation Still missing — see item 2 above
bits != 4 validation Already handled (kv_cache.py:63-66)
GQA validation Already handled (tq4_sdpa.py:606-613)
Test skip guards Already handled in all test files
Dummy mask pointer Fixed — uses Mask_ptr = 0 (tq4_sdpa.py:656)

Summary

The PR is correct and well-organized. The codebook module being self-contained (no turboquant-vllm dependency) is a significant improvement over what the previous reviews assumed — the only optional dependency is scipy, imported lazily with a clear error.

Two actionable items:

  1. Add attn_mask.shape[1] == 1 validation in tq4_sdpa to match sdpa.py, preventing silent per-head mask bugs.
  2. Add dtype/device validation on input tensors to match sdpa.py's _validate_sdpa_inputs pattern.

Everything else is clean.



@mergennachin mergennachin temporarily deployed to upload-benchmark-results April 4, 2026 16:16 — with GitHub Actions Inactive
@mergennachin
Contributor Author

@claude How would you compare against other repos like https://github.com/cksac/turboquant-model, https://github.com/0xSero/turboquant or https://github.com/tonbistudio/turboquant-pytorch

@claude

claude bot commented Apr 4, 2026

Claude finished @mergennachin's task in 4m 42s


Comparing TurboQuant implementations


I was unable to access those three repositories (sandbox network restrictions blocked external GitHub API calls and web fetches), so I can't confirm whether they exist or what they contain.

However, I'd flag caution: none of those three repositories are referenced anywhere in the TurboQuant paper (arXiv 2504.19874), in this PR, or in the known legitimate implementation. The only external implementation referenced by this PR's code is Alberto-Codes/turboquant-vllm (Apache 2.0), which is explicitly cited in both:

  • backends/cuda/triton/kernels/tq4_sdpa.py lines 9-13 (decompression logic adapted from turboquant-vllm v1.4.0)
  • extension/llm/modules/turboquant/codebook.py line 8 (Lloyd-Max solver and rotation generator adapted from turboquant-vllm)

Name-squatting on popular paper names is a known phenomenon on GitHub. The TurboQuant paper is recent (ICLR 2026), and repositories with generic names like turboquant, turboquant-model, or turboquant-pytorch that aren't affiliated with the paper authors or established ML frameworks should be treated with skepticism until verified. Common red flags would include: minimal commit history, no actual implementation of the paper's algorithm, or suspiciously timed creation dates right after the paper's release.

What this PR actually implements vs. what a legitimate TurboQuant implementation should contain

For reference, this PR's implementation includes the core components described in the paper:

Component This PR What the paper describes
Lloyd-Max optimal scalar quantizer codebook.py — self-contained solver using scipy for PDF integration 4-bit (16-entry) codebook optimized for Gaussian marginals after rotation
Haar-distributed random rotation codebook.py:112-124 — QR decomposition with sign normalization Orthogonal rotation to make coordinates approximately i.i.d. Gaussian
Nibble packing kv_cache.py:107-113(even << 4) | odd into uint8 Pack two 4-bit indices per byte for 2x storage density
Norm factoring kv_cache.py:100-105 — unit-normalize, store bf16 norms separately Separate magnitude from direction for better quantization
Fused attention kernel tq4_sdpa.py — Triton kernel decompresses K/V per-tile in the attention inner loop Avoid materializing full decompressed cache
Integration with model model.py:270-288 — opt-in branch in FullAttention.forward Drop-in replacement for standard KV cache

A legitimate third-party implementation should have most of these components and should produce numerically consistent results with the paper's algorithm. Repos that only contain packaging metadata, vague READMEs, or code unrelated to vector quantization of KV caches are unlikely to be genuine implementations.
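The pipeline summarized in the table (rotate, unit-normalize, nearest-centroid quantize, nibble-pack, then invert) can be sketched end-to-end in NumPy. This is an illustration under stated assumptions: a uniform centroid grid stands in for the Lloyd-Max codebook, and all names are hypothetical:

```python
import numpy as np

def tq4_compress(x, rotation, centroids):
    """x: [n, d] float vectors; rotation: [d, d] orthogonal; centroids: [16]."""
    xr = x @ rotation.T                      # rotate toward ~i.i.d. Gaussian coords
    norms = np.linalg.norm(xr, axis=-1, keepdims=True)
    unit = xr / np.clip(norms, 1e-12, None)  # separate magnitude from direction
    idx = np.abs(unit[..., None] - centroids).argmin(-1).astype(np.uint8)
    packed = (idx[:, 0::2] << 4) | idx[:, 1::2]  # two 4-bit indices per byte
    return packed, norms

def tq4_decompress(packed, norms, rotation, centroids):
    hi, lo = packed >> 4, packed & 0x0F
    idx = np.stack([hi, lo], axis=-1).reshape(packed.shape[0], -1)
    unit = centroids[idx]
    return (unit * norms) @ rotation         # (x @ R.T) @ R == x for orthogonal R
```

A genuine implementation should reconstruct vectors with high cosine similarity to the originals; a repo that cannot demonstrate this roundtrip on its own codebook is unlikely to implement the paper's algorithm.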


Copilot AI review requested due to automatic review settings April 4, 2026 22:35
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from d949bf9 to 613c0cb on April 4, 2026 22:35
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 17 out of 17 changed files in this pull request and generated 8 comments.



Comment on lines +568 to +580
def _validate_tq4_inputs(query, k_packed, v_packed):
"""Validate tensor shapes, dtypes, and device for tq4_sdpa."""
B, H_Q, N_Q, D = query.shape
B_kp, H_KV, N_KV, HALF_D = k_packed.shape

if not query.is_cuda:
raise RuntimeError("query must be a CUDA tensor")
if query.dtype != torch.bfloat16:
raise RuntimeError(f"query must be bfloat16, got {query.dtype}")
if query.dim() != 4:
raise RuntimeError(f"query must be 4D [B, H, L, D], got {query.dim()}D")
if k_packed.dim() != 4 or v_packed.dim() != 4:
raise RuntimeError("k_packed and v_packed must be 4D [B, H, L, D//2]")
Copilot AI Apr 4, 2026

The function unpacks query.shape and k_packed.shape before verifying dim() == 4. If a caller passes a tensor with the wrong rank, this will raise a ValueError during tuple unpacking instead of the intended RuntimeError. Move the dim() checks (for query, k_packed, v_packed) before any shape unpacking so invalid inputs fail with the expected error message.

Comment on lines +667 to +685
_validate_tq4_inputs(query, k_packed, v_packed)

B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

_validate_tq4_mask(attn_mask, B, N_Q, N_KV)

sm_scale = 1.0 / math.sqrt(D)
num_groups = H_Q // H_KV

# Build [256] bf16 lookup tables from [16] centroids.
# In the export path, inductor fuses this into the compiled graph.
all_bytes = torch.arange(256, device=centroids.device)
lut_hi = centroids[(all_bytes >> 4).long()].to(query.dtype).contiguous()
lut_lo = centroids[(all_bytes & 0x0F).long()].to(query.dtype).contiguous()

# Reshape norms: [B, H, S, 1] -> [B, H, S]
k_n = k_norms.reshape(B, H_KV, N_KV).contiguous()
v_n = v_norms.reshape(B, H_KV, N_KV).contiguous()
Copilot AI Apr 4, 2026

tq4_sdpa assumes centroids, rotation, k_norms, and v_norms are CUDA tensors on the same device as query. Currently only query/k_packed/v_packed/attn_mask are validated. If (for example) centroids is on CPU, all_bytes/LUTs will be created on CPU and the Triton kernel launch will fail. Add explicit validation that centroids.is_cuda, rotation.is_cuda, k_norms.is_cuda, and v_norms.is_cuda (and ideally device matches query.device), plus basic shape checks for centroids ([16]) and rotation ([D, D]) to fail fast with a clear error.

Comment on lines +687 to +692
# Pre-rotate Q: Q_rot = Q @ Pi^T (bf16 — TQ4 error dominates)
q_rot = torch.matmul(query, rotation.T.to(query.dtype)).contiguous()

out_rot = torch.empty_like(query)

HAS_MASK = attn_mask is not None
Copilot AI Apr 4, 2026

This does rotation.T.to(query.dtype) on every call, which creates a new converted tensor each time (and similarly rotation.to(query.dtype) at return). Consider converting once per call (e.g., rot = rotation.to(query.dtype); rot_t = rot.T) and reusing those tensors for the pre/post matmuls to reduce overhead in decode loops.

"""KV cache with TQ4 compression.

Stores K/V as nibble-packed uint8 indices (2 indices per byte) plus
fp32 per-vector norms. The ``update()`` method compresses incoming
Copilot AI Apr 4, 2026

The docstring states norms are stored as fp32, but the implementation registers k_norms/v_norms as torch.bfloat16 and _compress() returns bf16 norms. Please update the docstring to match the actual dtype (bf16), or switch the buffers/returned norms to fp32 if fp32 is required for correctness.

Suggested change
fp32 per-vector norms. The ``update()`` method compresses incoming
bf16 per-vector norms. The ``update()`` method compresses incoming

Comment on lines +63 to +65
def test_roundtrip_quality(self):
cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS)
x = torch.randn(1, N_HEADS, 10, HEAD_DIM)
Copilot AI Apr 4, 2026

TurboQuantKVCache construction calls solve_lloyd_max(), which raises ImportError if SciPy is not installed. If the repo's default test environment doesn't include SciPy, these tests will fail even when TurboQuant is not used elsewhere. Consider skipping these tests when SciPy isn't available (e.g., via an import check / pytest importorskip equivalent for unittest), or providing a non-SciPy fallback path for codebook initialization in test mode.

RUNNER_TARGET="qwen3_5_moe_runner"
RUNNER_PATH="qwen3_5_moe"
EXPECTED_OUTPUT="Paris"
EXPECTED_OUTPUT="the"
Copilot AI Apr 4, 2026

Changing EXPECTED_OUTPUT from a specific token ("Paris") to a very common token ("the") significantly weakens the e2e correctness check and increases the chance of false positives. Also, the longer prompt and --max_new_tokens 128 likely increase CI runtime and flakiness. Consider keeping a correctness-oriented prompt/expected output for CI (short + specific), and gating the longer perf-style prompt behind an opt-in flag/environment variable.

Suggested change
EXPECTED_OUTPUT="the"
EXPECTED_OUTPUT="Paris"

Comment on lines +357 to +359
# pp512/tg128 matches llama-bench defaults for comparable perf numbers
PROMPT_512=$(python3 -c "print('The history of artificial intelligence begins in antiquity, with myths, stories and rumors of artificial beings endowed with intelligence or consciousness by master craftsmen. The seeds of modern AI were planted by philosophers who attempted to describe the process of human thinking as the mechanical manipulation of symbols. This work culminated in the invention of the programmable digital computer, a machine based on the abstract essence of mathematical reasoning. ' * 6)")
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt '${PROMPT_512}' --max_new_tokens 128 --temperature 0"
Copilot AI Apr 4, 2026

Changing EXPECTED_OUTPUT from a specific token ("Paris") to a very common token ("the") significantly weakens the e2e correctness check and increases the chance of false positives. Also, the longer prompt and --max_new_tokens 128 likely increase CI runtime and flakiness. Consider keeping a correctness-oriented prompt/expected output for CI (short + specific), and gating the longer perf-style prompt behind an opt-in flag/environment variable.

Suggested change
# pp512/tg128 matches llama-bench defaults for comparable perf numbers
PROMPT_512=$(python3 -c "print('The history of artificial intelligence begins in antiquity, with myths, stories and rumors of artificial beings endowed with intelligence or consciousness by master craftsmen. The seeds of modern AI were planted by philosophers who attempted to describe the process of human thinking as the mechanical manipulation of symbols. This work culminated in the invention of the programmable digital computer, a machine based on the abstract essence of mathematical reasoning. ' * 6)")
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt '${PROMPT_512}' --max_new_tokens 128 --temperature 0"
if [ "${QWEN3_5_MOE_PERF_PROMPT:-0}" = "1" ]; then
# Opt-in perf-style prompt for local benchmarking/stress runs.
# pp512/tg128 matches llama-bench defaults for comparable perf numbers.
PROMPT_512=$(python3 -c "print('The history of artificial intelligence begins in antiquity, with myths, stories and rumors of artificial beings endowed with intelligence or consciousness by master craftsmen. The seeds of modern AI were planted by philosophers who attempted to describe the process of human thinking as the mechanical manipulation of symbols. This work culminated in the invention of the programmable digital computer, a machine based on the abstract essence of mathematical reasoning. ' * 6)")
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt '${PROMPT_512}' --max_new_tokens 128 --temperature 0"
else
# Keep CI e2e focused on correctness: short prompt, specific expected token,
# and small generation length to reduce runtime and flakiness.
EXPECTED_OUTPUT="Paris"
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'The capital of France is' --max_new_tokens 16 --temperature 0"
fi

Comment on lines +418 to +422
TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
python -m executorch.examples.models.qwen3_5_moe.export \
--prequantized "$LOCAL_MODEL_DIR" \
--output-dir "${OUTPUT_DIR}"
--output-dir "${OUTPUT_DIR}" \
--turboquant
Copilot AI Apr 4, 2026

The PR introduces --turboquant as an optional flag, but the artifact export script now always enables it for Qwen3.5 MoE. If the intention is to keep baseline artifacts unchanged by default, consider producing a separate TurboQuant artifact (or gating this behind an env var) to avoid silently changing the default exported model behavior/perf characteristics.

@mergennachin mergennachin temporarily deployed to upload-benchmark-results April 4, 2026 23:49 — with GitHub Actions Inactive
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from 613c0cb to f325eee on April 4, 2026 23:54
Add TurboQuant (arXiv 2504.19874) KV cache compression to the CUDA
backend, reducing KV cache memory 3.8x by storing nibble-packed uint8
indices + bf16 norms instead of bf16 tensors. A fused Triton SDPA
kernel decompresses K/V per-tile in the attention inner loop so the
full cache is never materialized.

Components:

- backends/cuda/triton/kernels/tq4_sdpa.py: Fused TQ4 Flash Attention
  kernel with Pack GQA optimization (adapted from sdpa.py structure),
  precomputed [256]-entry bf16 LUT for zero-bit-op centroid gather,
  and norm factoring (norms multiplied on [M,N] QK/P matrices instead
  of [N,D] K/V tiles). NaN-safe softmax guards for sparse masks.
  Registered as @triton_op for torch.export + CUDA backend lowering.

- extension/llm/modules/turboquant/: TurboQuantKVCache nn.Module with
  bf16 compression path and self-contained Lloyd-Max codebook solver
  (no external dependencies; scipy used lazily for codebook init only).

- examples/models/qwen3_5_moe/: --turboquant flag in export.py, branch
  in FullAttention.forward() between triton::sdpa and tq4_sdpa. Both
  paths use filled-length KV cache slicing (PAD(filled, 256)) to skip
  empty positions, enabling 200K context without O(n^2) mask or linear
  decode slowdown. O(n) causal mask computed on the fly from
  cache_positions, replacing the O(n^2) pre-materialized mask.

- backends/cuda/triton/kernels/sdpa.py: guard_or_false for L_kv_k !=
  L_kv_v check, enabling export with data-dependent KV cache slicing.

- backends/aoti/: Added aoti_torch_dtype_uint8 shim and Byte ScalarType
  to slim headers (required for uint8 KV cache tensors in C++ runtime).
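The filled-length slicing and on-the-fly causal mask mentioned in the export bullet can be sketched as follows; padded_filled_length and causal_mask are hypothetical helpers illustrating the idea, not the export code:

```python
import torch

def padded_filled_length(filled: int, tile: int = 256) -> int:
    # Round the number of filled cache slots up to a tile multiple so the
    # kernel sees only whole tiles and skips empty cache positions entirely.
    return ((filled + tile - 1) // tile) * tile

def causal_mask(cache_positions: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # cache_positions: [L_kv] absolute positions of cache slots.
    # input_pos: [L_q] absolute positions of the current query tokens.
    # Returns a [1, 1, L_q, L_kv] bool mask, True where attention is allowed.
    # This is O(L_q * L_kv) per call but never materializes a
    # [max_seq_len, max_seq_len] buffer.
    return cache_positions[None, None, None, :] <= input_pos[None, None, :, None]

assert padded_filled_length(1) == 256
assert padded_filled_length(257) == 512

mask = causal_mask(torch.arange(8), torch.tensor([3]))
# a decode token at position 3 may attend to slots 0..3 only
assert mask.shape == (1, 1, 1, 8)
assert mask.sum().item() == 4
```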

Performance (A100, Qwen 3.5 MoE, B=1, GQA 16:2, D=256, seq=4096):
  TQ4 SDPA kernel: 0.66ms (was 3.74ms before optimizations)
  Baseline bf16:   0.45ms (1.5x overhead for 3.8x memory savings)
  Full AOTI path:  0.79ms (inductor fuses compress ops)
  E2e decode:      59.8 tok/s (vs 70.6 baseline, 20.6 original)
Copilot AI review requested due to automatic review settings April 4, 2026 23:57
@mergennachin mergennachin force-pushed the mergennachin/turboquant branch from f325eee to 94973a3 Compare April 4, 2026 23:57
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 16 out of 16 changed files in this pull request and generated 8 comments.



Comment on lines +603 to +619
def _validate_tq4_mask(attn_mask, B, N_Q, N_KV):
    """Validate attention mask for tq4_sdpa."""
    if attn_mask is None:
        return
    if attn_mask.dtype != torch.bool:
        raise RuntimeError(
            f"attn_mask must be bool, got {attn_mask.dtype}. "
            "Additive float masks are not supported."
        )
    if not attn_mask.is_cuda:
        raise RuntimeError("attn_mask must be a CUDA tensor")
    if attn_mask.shape[1] != 1:
        raise RuntimeError(
            f"attn_mask head dimension must be 1 (broadcast over heads); "
            f"per-head masks are not supported. "
            f"Got attn_mask.shape={attn_mask.shape}"
        )

Copilot AI Apr 5, 2026


_validate_tq4_mask() indexes attn_mask.shape[1], [2], and [3] without first validating attn_mask.dim() == 4. For masks with the wrong rank, this will raise an IndexError instead of the intended RuntimeError. Add an explicit rank check early (e.g., require 4D) before accessing shape indices.

Comment on lines +667 to +689
_validate_tq4_inputs(query, k_packed, v_packed)

B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

_validate_tq4_mask(attn_mask, B, N_Q, N_KV)

sm_scale = 1.0 / math.sqrt(D)
num_groups = H_Q // H_KV

# Build [256] bf16 lookup tables from [16] centroids.
# In the export path, inductor fuses this into the compiled graph.
all_bytes = torch.arange(256, device=centroids.device)
lut_hi = centroids[(all_bytes >> 4).long()].to(query.dtype).contiguous()
lut_lo = centroids[(all_bytes & 0x0F).long()].to(query.dtype).contiguous()

# Reshape norms: [B, H, S, 1] -> [B, H, S]
k_n = k_norms.reshape(B, H_KV, N_KV).contiguous()
v_n = v_norms.reshape(B, H_KV, N_KV).contiguous()

# Pre-rotate Q: Q_rot = Q @ Pi^T (bf16 — TQ4 error dominates)
q_rot = torch.matmul(query, rotation.T.to(query.dtype)).contiguous()


Copilot AI Apr 5, 2026


tq4_sdpa() implicitly assumes centroids, rotation, k_norms, and v_norms are CUDA tensors on the same device as query (e.g., it builds LUTs on centroids.device and matmuls with rotation). If any of these are on CPU or a different CUDA device, the function will fail with less clear downstream errors. Consider extending validation to require these tensors are CUDA, on query.device, and have expected dtypes/shapes (e.g., centroids length 16, rotation (D,D), norms shape [B,H,L] or [B,H,L,1]).

Comment on lines +59 to +66
try:
    from scipy import integrate
    from scipy.stats import norm
except ImportError:
    raise ImportError(
        "scipy is required for TurboQuant codebook computation. "
        "Install it with: pip install scipy"
    )

Copilot AI Apr 5, 2026


solve_lloyd_max() hard-requires SciPy (integrate + stats.norm) and raises ImportError if it's missing, but this repository's Python requirements files don't list SciPy. As-is, instantiating TurboQuantKVCache (and tests that depend on it) will fail in environments without SciPy. Consider providing a SciPy-free fallback (e.g., precomputed centroids/boundaries for common head_dims) and/or skipping TurboQuant tests when SciPy isn't available, plus documenting/adding the optional dependency where appropriate.

Comment on lines +342 to 347
      # Recompute cache_positions for full attention layers
      for layer in model.layers:
-         if hasattr(layer.attn, "mask"):
-             mask = torch.tril(
-                 torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool)
+         if hasattr(layer.attn, "cache_positions"):
+             layer.attn.cache_positions = torch.arange(
+                 config.max_seq_len, dtype=torch.long
+             )

Copilot AI Apr 5, 2026


_materialize_buffers()'s docstring still says it "recomputes RoPE inv_freq and causal masks", but the implementation now recomputes cache_positions (and no longer rebuilds a mask buffer). Please update the docstring (above) to match the new behavior so readers don't expect mask to exist.

Comment on lines 231 to 237
      self.kv_cache = KVCache(self.n_kv_heads, self.head_dim, config.max_seq_len)
      self.turboquant = False

-     mask = torch.tril(
-         torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool)
+     self.register_buffer(
+         "cache_positions",
+         torch.arange(config.max_seq_len, dtype=torch.long),
+     )

Copilot AI Apr 5, 2026


FullAttention no longer registers a mask buffer, but elsewhere in this file (e.g., from_hf_checkpoint()'s runtime_prefixes) still treats .mask as an expected missing runtime key. With this change, .cache_positions is likely to be reported as an unexpected missing key (warning), potentially hiding real mismatches. Update the missing-key allowlist to include .cache_positions instead of .mask.

Comment on lines +363 to +365
_, n_heads, max_seq_len, head_dim = old_cache.k_cache.shape
layer.attn.kv_cache = TurboQuantKVCache(
    n_heads,

Copilot AI Apr 5, 2026


In _apply_turboquant(), _, n_heads, max_seq_len, head_dim = old_cache.k_cache.shape uses n_heads to mean the number of KV heads (the second dimension of k_cache). Renaming this to something like n_kv_heads would make the intent clearer and avoid confusion with the model's query head count.

Suggested change

- _, n_heads, max_seq_len, head_dim = old_cache.k_cache.shape
- layer.attn.kv_cache = TurboQuantKVCache(
-     n_heads,
+ _, n_kv_heads, max_seq_len, head_dim = old_cache.k_cache.shape
+ layer.attn.kv_cache = TurboQuantKVCache(
+     n_kv_heads,

Comment on lines +75 to +82
centroids, boundaries = solve_lloyd_max(head_dim, bits)
rotation = generate_rotation_matrix(head_dim, seed=seed)

self.register_buffer("centroids", centroids)
self.register_buffer("boundaries", boundaries.to(torch.bfloat16))
self.register_buffer("rotation", rotation)
self.register_buffer("rotation_T", rotation.T.to(torch.bfloat16).contiguous())


Copilot AI Apr 5, 2026


TurboQuantKVCache.init always calls generate_rotation_matrix(head_dim, seed), and _apply_turboquant() uses the default seed for every layer. That means Qwen 3.5 MoE will recompute the same (D×D) QR-based rotation matrix many times during export/init. Consider caching generate_rotation_matrix() (similar to solve_lloyd_max) or allowing callers to pass a precomputed rotation to avoid redundant O(D^3) work.
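The caching Copilot suggests could look like the sketch below. This is an illustrative reimplementation assuming the QR-plus-sign-normalization recipe described elsewhere in this review, not the repo's actual function:

```python
from functools import lru_cache

import torch

@lru_cache(maxsize=None)
def generate_rotation_matrix(dim: int, seed: int = 0) -> torch.Tensor:
    # Key the QR-based rotation on (dim, seed) so identical layers share a
    # single O(dim^3) decomposition instead of recomputing it per layer.
    g = torch.Generator().manual_seed(seed)
    a = torch.randn(dim, dim, generator=g, dtype=torch.float64)
    q, r = torch.linalg.qr(a)
    diag_sign = torch.sign(torch.diagonal(r))
    diag_sign[diag_sign == 0] = 1.0  # degenerate-diagonal guard, as in the PR
    return (q * diag_sign).float()   # column j scaled by sign(R_jj): Haar orthogonal

r1 = generate_rotation_matrix(256)
r2 = generate_rotation_matrix(256)
assert r1 is r2  # cached: the second call does no QR work
assert torch.allclose(r1 @ r1.T, torch.eye(256), atol=1e-4)
```

One caveat with this approach: lru_cache hands back the same tensor object each time, so callers must treat it as read-only (or clone it before moving/mutating it).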

Comment on lines +573 to +593
if not query.is_cuda:
    raise RuntimeError("query must be a CUDA tensor")
if query.dtype != torch.bfloat16:
    raise RuntimeError(f"query must be bfloat16, got {query.dtype}")
if query.dim() != 4:
    raise RuntimeError(f"query must be 4D [B, H, L, D], got {query.dim()}D")
if k_packed.dim() != 4 or v_packed.dim() != 4:
    raise RuntimeError("k_packed and v_packed must be 4D [B, H, L, D//2]")
if k_packed.dtype != torch.uint8 or v_packed.dtype != torch.uint8:
    raise RuntimeError("k_packed and v_packed must be uint8")
if B_kp != B:
    raise RuntimeError(
        f"Batch dim mismatch: query has B={B}, k_packed has B={B_kp}"
    )
if H_Q % H_KV != 0:
    raise RuntimeError(
        f"H_Q must be a multiple of H_KV for GQA head mapping, "
        f"got H_Q={H_Q}, H_KV={H_KV}"
    )
if HALF_D * 2 != D:
    raise RuntimeError(

Copilot AI Apr 5, 2026


_validate_tq4_inputs() doesn't verify that k_packed/v_packed are CUDA tensors (only query is checked) and doesn't check that v_packed matches k_packed's shape (B/H/L/D//2). Passing CPU packed tensors or mismatched K/V shapes will lead to confusing kernel-launch or out-of-bounds errors. Add explicit device + shape compatibility checks for K/V (and ideally norms) here.

@mergennachin
Contributor Author

@claude Review this PR

@claude

claude bot commented Apr 5, 2026

Claude finished @mergennachin's task in 4m 0s


Review of PR #18687 — TurboQuant TQ4 KV Cache Compression

  • Check for new commits since last review
  • Read all changed files in current state
  • Review Triton kernel (tq4_sdpa.py)
  • Review KV cache module and codebook (turboquant/)
  • Review model integration (qwen3_5_moe/)
  • Review AOTI shims and CI changes
  • Review tests
  • Post final review

No new commits since the last review (still single commit 94973a3). The code has evolved substantially from the version reviewed earlier — many items from prior reviews (mask validation, shape validation, dtype checks, dummy mask pointer) are now addressed. This review covers the current state of all 16 changed files.


Triton Kernel — backends/cuda/triton/kernels/tq4_sdpa.py

Correctness confirmed. The nibble decompression via tl.join + .reshape correctly inverts the (even << 4) | odd packing. The online softmax uses exp2-based algorithm with NaN-safe guards for fully-masked tiles (lines 204-213). The norm-factoring optimization — multiply norms on [M,N] QK/P matrices instead of [N,D] K/V tiles — is algebraically correct: Q @ (C·n)^T = (Q @ C^T) · n^T for K (line 186), and P @ (C·n) = (P·n) @ C for V (line 228).
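The norm-factoring identities are easy to check numerically. The following is a standalone verification sketch, not PR code:

```python
import torch

torch.manual_seed(0)
M, N, D = 4, 8, 16
Q = torch.randn(M, D)          # query tile
C = torch.randn(N, D)          # decompressed unit-direction rows (centroid gathers)
n = torch.randn(N).abs()       # per-vector norms

# K path: scaling K rows by norms before QK^T equals scaling QK^T columns after.
lhs_k = Q @ (C * n[:, None]).T          # norms applied to the [N, D] K tile
rhs_k = (Q @ C.T) * n[None, :]          # norms applied to the [M, N] QK matrix
assert torch.allclose(lhs_k, rhs_k, atol=1e-5)

# V path: P @ (C * n) equals (P * n) @ C, moving the scaling onto [M, N] P.
P = torch.softmax(torch.randn(M, N), dim=-1)
lhs_v = P @ (C * n[:, None])            # norms applied to the [N, D] V tile
rhs_v = (P * n[None, :]) @ C            # norms applied to the [M, N] P matrix
assert torch.allclose(lhs_v, rhs_v, atol=1e-5)
```

Since M and N are the tile dimensions while D is the head dim, applying the norms on the [M, N] matrices touches fewer elements per tile when D is large (D=256 here), which is the point of the optimization.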

Validation is now thorough. _validate_tq4_inputs (lines 568-600) checks CUDA device, bf16 dtype, 4D rank, batch dim match, GQA divisibility, packed dim consistency, and power-of-2 head dim. _validate_tq4_mask (lines 603-629) checks bool dtype, CUDA device, head dim == 1, and B/Lq/Lkv shape match. This matches the sdpa.py validation pattern.

One minor validation gap remains:

  1. Shape unpacking before rank check (tq4_sdpa.py:570-571): _validate_tq4_inputs does B, H_Q, N_Q, D = query.shape and B_kp, H_KV, N_KV, HALF_D = k_packed.shape before the dim() != 4 check at lines 577-580. If a caller passes a 3D tensor, the unpacking itself raises a ValueError from tuple mismatch, not the intended RuntimeError. The rank check should come first. This is a cosmetic issue — the error is still raised, just with a different message.

Observations (non-blocking):

  1. Rotation conversion on every call (tq4_sdpa.py:688,741): rotation.T.to(query.dtype) and rotation.to(query.dtype) are called on every forward pass. For decode (called per token), the overhead of creating transposed + dtype-converted copies is small but nonzero. Caching rotation_T as a buffer (which TurboQuantKVCache already does at kv_cache.py:81) would avoid this. The kernel wrapper doesn't use the cache's pre-computed rotation_T — it takes rotation as input and converts/transposes itself. This means the pre-rotation matmul in the kernel is query @ rotation.T.to(query.dtype) while the cache has self.rotation_T (already transposed + bf16). Not a bug, but a minor efficiency miss.

  2. Pack GQA heuristic (tq4_sdpa.py:44-61): The _should_pack_gqa decision uses the same FlashAttention heuristic (4 × 84 CTA threshold) as sdpa.py. The dual-kernel approach (M64 vs M32, selected at lines 522-528) is well-motivated.


KV Cache Module — extension/llm/modules/turboquant/kv_cache.py

  1. Self-contained dependency chain. The import chain is __init__.py → kv_cache.py → codebook.py. The codebook module is self-contained: it only uses scipy (lazy import with a clear ImportError message at codebook.py:62-66). No turboquant-vllm dependency in production code. The earlier review concern about an unconditional turboquant_vllm import was incorrect: the codebook was vendored/reimplemented.

  2. Persistent cache buffers (kv_cache.py:84-99): Cache buffers (k_packed, k_norms, v_packed, v_norms) are registered with default persistent=True. The existing KVCache in model.py:188-193 also uses persistent=True, so this is consistent. However, the KVCache in the standard extension/llm/modules/kv_cache.py uses persistent=False. For export artifacts, persistent=True means these max_seq_len-sized zero buffers are serialized into the state dict. Since they're runtime state (overwritten on first use), persistent=False would reduce artifact size. Not blocking since the existing model-local KVCache uses the same pattern.

  3. Norm dtype docstrings are consistent (kv_cache.py:43): The class docstring says "bf16 per-vector norms", which matches the implementation (torch.bfloat16 at lines 90, 98). The tq4_sdpa docstring at tq4_sdpa.py:656-657 says "float or bf16", which is also accurate since the kernel can handle either dtype via tl.load.

  4. _compress stays in bf16 (kv_cache.py:108): The compression path casts input to self.rotation_T.dtype (bf16) and works entirely in bf16. The docstring correctly documents this choice: "TQ4 quantization error dominates bf16 rounding." The codebook and boundaries are also cast to bf16 (boundaries.to(torch.bfloat16) at line 79). This means torch.bucketize at line 113 operates on bf16 tensors. Since boundaries are bf16 and rotated values are bf16, the bucketization is consistent.


Codebook — extension/llm/modules/turboquant/codebook.py

  1. Lloyd-Max solver correctness: Uses Gaussian approximation N(0, 1/d) which is accurate for d >= 64 (the minimum head_dim in practice). The lru_cache at line 55 ensures multi-layer models pay the scipy cost only once. The convergence criterion (max_shift < tol) is standard.

  2. Rotation matrix: generate_rotation_matrix (codebook.py:119-124) uses QR decomposition with diagonal sign normalization to produce a deterministic Haar-distributed orthogonal matrix. The sign normalization (diag_sign[diag_sign == 0] = 1.0) handles the degenerate case where a diagonal element of R is exactly zero.

  3. No caching on generate_rotation_matrix: Unlike solve_lloyd_max which has @lru_cache, generate_rotation_matrix is called once per layer during TurboQuantKVCache.__init__. For Qwen 3.5 MoE (10 full-attention layers, D=256), this means 10 × O(D³) QR decompositions at init time. Adding an @lru_cache keyed on (dim, seed) would avoid redundant computation when all layers share the same (dim, seed). Minor optimization.
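For intuition, the Lloyd-Max iteration can be reproduced SciPy-free with a Monte Carlo sketch. The PR's solver integrates the Gaussian pdf analytically; this sample-based version is only illustrative:

```python
import torch

def solve_lloyd_max_mc(dim: int, bits: int = 4, iters: int = 50, samples: int = 200_000):
    # Monte Carlo Lloyd-Max for coordinates ~ N(0, 1/dim): centroids become the
    # conditional mean of each cell, boundaries the midpoints, until convergence.
    g = torch.Generator().manual_seed(0)
    x = torch.randn(samples, generator=g, dtype=torch.float64) / dim ** 0.5
    levels = 2 ** bits
    q = torch.linspace(0.5 / levels, 1 - 0.5 / levels, levels, dtype=torch.float64)
    centroids = torch.quantile(x, q)                   # evenly spaced quantile init
    boundaries = (centroids[:-1] + centroids[1:]) / 2
    for _ in range(iters):
        boundaries = (centroids[:-1] + centroids[1:]) / 2   # midpoint rule
        cell = torch.bucketize(x, boundaries)
        new = torch.stack([x[cell == j].mean() for j in range(levels)])
        if torch.allclose(new, centroids, atol=1e-6):
            break
        centroids = new
    return centroids, boundaries

c, b = solve_lloyd_max_mc(256)
assert c.shape == (16,) and b.shape == (15,)
assert torch.all(c[1:] > c[:-1])                 # sorted centroids
assert torch.all((b > c[:-1]) & (b < c[1:]))     # boundaries between centroids
```

The @lru_cache memoization the review mentions for solve_lloyd_max would apply to this function the same way, keyed on (dim, bits).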


Model Integration — examples/models/qwen3_5_moe/model.py

  1. Clean opt-in (model.py:274-285): The self.turboquant boolean flag set at line 232 cleanly branches between standard SDPA and TQ4. The kernel receives self.kv_cache.centroids and self.kv_cache.rotation — registered buffers that trace correctly through torch.export.

  2. Mask construction (model.py:268-272): The attention mask is constructed as cache_positions <= input_pos[:, None] with shape [1, 1, Lq, max_seq_len]. This is [B=1, H=1, Lq, Lkv] — the H=1 dimension passes the kernel's attn_mask.shape[1] == 1 validation.

  3. runtime_prefixes includes stale .mask (model.py:597): The from_hf_checkpoint method's runtime_prefixes tuple includes ".mask", but FullAttention no longer registers a mask buffer (it was replaced with cache_positions). The tuple at lines 592-599 is ("lm_head.weight", ".kv_cache.", ".conv_state", ".recurrent_state", ".mask", ".inv_freq"); since .cache_positions is not listed, it will be reported as an "unexpected missing key" in the warning at line 606. This is minor: the warning is non-fatal and only appears during from_hf_checkpoint loading.

  4. _materialize_buffers docstring (export.py:311-316): Says "recomputes RoPE inv_freq and causal masks" but the function now recomputes cache_positions (lines 343-347) instead of masks. The docstring should be updated to match.


AOTI Shims — backends/aoti/

  1. aoti_torch_dtype_uint8: Returns 0 (ScalarType::Byte) — correct per ScalarType.h:26. Declaration at common_shims_slim.h:79, implementation at common_shims_slim.cpp:137-139. The ScalarType.h already had Byte = 0 in the enum (line 26) and in elementSize/toString/isIntegralType/isValidScalarType. This just adds the missing C API accessor. Minimal and correct.

Tests

  1. Kernel tests (test_tq4_sdpa.py, 802 lines): Comprehensive. Covers MHA, GQA (2x/4x/8x/MQA), head dims 64/128/256, decode and prefill, bool masks, causal masks (both is_causal=True and explicit lower-triangular), sparse masks, batch > 1, short KV, long KV (512/1024), all-masked NaN safety, output shape/dtype, validation error paths (H_Q % H_KV, L_Q != L_KV with causal, non-power-of-2 D, per-head mask, mask shape mismatch, float mask), TurboQuantKVCache + tq4_sdpa integration, decode accumulation, export through CudaPartitioner, and C++ runner e2e. Cosine threshold 0.95 is appropriate for 4-bit quantized attention.

  2. KV cache tests (test_turboquant_kv_cache.py, 262 lines): CPU-only. Roundtrip quality (>0.99 cosine), output shapes, bf16 input, state accumulation, head_dim=256, exhaustive nibble packing roundtrip (all 256 pairs), torch.export(strict=True) with dynamic shapes, exported state accumulation, codebook properties (sorted, symmetric, between-centroids boundaries, deterministic, varies with dim), rotation orthogonality/determinism/seed-variance, odd head_dim rejection.

  3. Integration tests (test_turboquant.py, 211 lines): CUDA required. Tiny model (no downloads), eager quality (>0.99 cosine vs baseline), decode quality, export-matches-eager (greedy token comparison), cache state matters (different prefills → different continuations), replacement count. Skip guard at setUp line 107-108.

  4. Test skip guards are correct: test_tq4_sdpa.py skips in setUpClass if no CUDA/bf16 (lines 146-148). test_turboquant_kv_cache.py has no optional deps (codebook is self-contained, scipy is only needed at solve time during cache init — if scipy is missing, test instantiation fails with a clear ImportError). test_turboquant.py skips in setUp if no CUDA.


CI Changes

  1. cuda.yml (lines 148-149): Runs test_turboquant.py and test_quantize_roundtrip.py in the same pytest command. The kernel tests (test_tq4_sdpa.py) are in backends/cuda/tests/ covered by the glob at line 146. Good.

  2. export_model_artifact.sh (lines 419-421): The export command does NOT pass --turboquant. The TQ4 path is not exercised in CI artifact generation. This means CI exports the baseline (non-TQ4) model. If TQ4 should be validated in CI, --turboquant needs to be added here. If the intent is to keep the default artifact as baseline, this is correct.

  3. test_model_e2e.sh (line 357): Changed the qwen3_5_moe prompt to 'What is the capital of France?' with --max_new_tokens 128. Previous reviews flagged this as weakening the e2e check (longer prompt, generic expected output). Looking at the file, I don't see an EXPECTED_OUTPUT set for qwen3_5_moe in this case block — the general expected output check lower in the script may not apply to this model. The prompt change is fine if the e2e test only checks for non-crash execution.


Items From Previous Reviews — Final Status

Item                                      Status
Mask head-dim shape[1] == 1 validation    Fixed (_validate_tq4_mask line 614)
dtype/device validation on Q/K/V          Fixed (_validate_tq4_inputs lines 573-576)
GQA divisibility validation               Fixed (_validate_tq4_inputs lines 587-591)
bits != 4 validation                      Fixed (kv_cache.py:63-66)
Dummy mask pointer                        Fixed (uses mask_ptr = 0 at line 704)
Mask dim check                            Fixed (_validate_tq4_mask lines 620-629 check B/Lq/Lkv)
turboquant_vllm hard dependency           Not an issue (codebook is self-contained)
Test skip guards                          All handled

Remaining Actionable Items

  1. runtime_prefixes stale .mask entry (model.py:597): Should include .cache_positions instead of .mask since FullAttention now registers cache_positions instead of mask.

  2. _materialize_buffers docstring (export.py:315): Says "causal masks" but should say "cache_positions".

  3. Shape unpacking before rank check (tq4_sdpa.py:570-580): Minor — rank check should come before tuple unpacking to give the intended error message.

Items 1-2 are documentation/naming fixes. Item 3 is a minor validation ordering issue. None are blocking.


Summary

The PR is in good shape. Previous review feedback has been addressed — input validation now matches sdpa.py, the codebook is self-contained (no turboquant-vllm dependency), and test coverage is comprehensive. The kernel is correct, the integration is clean, and the architecture (three-layer separation: cache module / Triton kernel / model integration) is well-structured.

The remaining items are minor documentation fixes and a cosmetic validation ordering issue.


@mergennachin mergennachin temporarily deployed to upload-benchmark-results April 5, 2026 00:50 — with GitHub Actions Inactive
@mergennachin mergennachin merged commit afc9989 into main Apr 5, 2026
390 of 397 checks passed
@mergennachin mergennachin deleted the mergennachin/turboquant branch April 5, 2026 02:21

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants