TurboQuant TQ4 KV cache compression for Qwen 3.5 MoE #18687
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18687
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 3 Unrelated Failures as of commit 94973a3 with merge base 300e368.
NEW FAILURES — The following jobs have failed:
BROKEN TRUNK — The following jobs failed but were present on the merge base.
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
Force-pushed a292381 to 969767b
@claude Review this PR
Claude finished @mergennachin's task in 6m 45s
Review of PR #18687 — TurboQuant TQ4 KV Cache Compression
Overall this is a well-structured PR. The kernel, cache module, model integration, tests, and CI changes are all consistent. I have a few observations and suggestions below.
Triton Kernel —
Pull request overview
Adds TurboQuant TQ4 KV-cache compression and a fused Triton SDPA path to reduce KV-cache memory for Qwen 3.5 MoE, with export/e2e coverage and CI wiring.
Changes:
- Introduces `TurboQuantKVCache` that compresses K/V into nibble-packed `uint8` indices + `fp32` norms and exposes an update API for attention.
- Adds a fused Triton op `triton::tq4_sdpa` that performs inline K/V decompression per tile during attention.
- Integrates a `--turboquant` export flag + Qwen 3.5 MoE model switch, plus CI/test updates and an AOTI uint8 dtype shim.
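For readers unfamiliar with the nibble-packing scheme the summary describes (two 4-bit indices per `uint8` byte), here is a minimal NumPy sketch; the function names and nibble ordering are illustrative, not the PR's actual helpers:

```python
import numpy as np

def pack_nibbles(indices: np.ndarray) -> np.ndarray:
    """Pack pairs of 4-bit indices (values 0..15) into single uint8 bytes."""
    assert indices.shape[-1] % 2 == 0
    hi = indices[..., 0::2].astype(np.uint8)  # even positions -> high nibble
    lo = indices[..., 1::2].astype(np.uint8)  # odd positions  -> low nibble
    return (hi << 4) | lo

def unpack_nibbles(packed: np.ndarray) -> np.ndarray:
    """Invert pack_nibbles: recover the original 4-bit index stream."""
    hi = (packed >> 4) & 0x0F
    lo = packed & 0x0F
    out = np.empty(packed.shape[:-1] + (packed.shape[-1] * 2,), dtype=np.uint8)
    out[..., 0::2] = hi
    out[..., 1::2] = lo
    return out

# Round-trip check: packing halves the last dim, unpacking restores it exactly.
idx = np.random.default_rng(0).integers(0, 16, size=(4, 8), dtype=np.uint8)
assert pack_nibbles(idx).shape == (4, 4)
assert np.array_equal(unpack_nibbles(pack_nibbles(idx)), idx)
```

This is where the 3.8x memory figure comes from: each bf16 element (2 bytes) becomes half a byte of index, plus a shared per-vector norm.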
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 9 comments.
Summary per file:
| File | Description |
|---|---|
| extension/llm/modules/turboquant/kv_cache.py | Implements KV cache compression/decompression and stateful updates for export. |
| extension/llm/modules/turboquant/init.py | Exposes the TurboQuant module API. |
| extension/llm/modules/test/test_turboquant_kv_cache.py | CPU unit tests for compress/decompress, packing, and torch.export. |
| examples/models/qwen3_5_moe/test_turboquant.py | CUDA E2E test validating model quality + export behavior with TurboQuant enabled. |
| examples/models/qwen3_5_moe/model.py | Adds a TurboQuant attention path that calls torch.ops.triton.tq4_sdpa. |
| examples/models/qwen3_5_moe/export.py | Adds --turboquant and replaces full-attention KV caches with TurboQuantKVCache. |
| backends/cuda/triton/kernels/tq4_sdpa.py | New fused Triton SDPA kernel/op for attention over compressed KV cache. |
| backends/cuda/triton/kernels/init.py | Registers the new Triton op for import side-effects. |
| backends/cuda/tests/test_tq4_sdpa.py | CUDA kernel correctness + export/runner coverage for tq4_sdpa. |
| backends/aoti/common_shims_slim.h | Adds missing aoti_torch_dtype_uint8() declaration. |
| backends/aoti/common_shims_slim.cpp | Implements aoti_torch_dtype_uint8() to fix runner symbol resolution. |
| .github/workflows/cuda.yml | Installs turboquant-vllm and runs the new TurboQuant tests in CUDA CI. |
| .ci/scripts/export_model_artifact.sh | Ensures Qwen 3.5 MoE export installs turboquant-vllm and enables --turboquant. |
```python
from executorch.extension.llm.modules.turboquant import (
    TurboQuantKVCache,
    replace_kv_cache_with_turboquant,
)

# After model construction, before torch.export:
replace_kv_cache_with_turboquant(model, kv_cache_class=KVCache)
```
The module docstring references replace_kv_cache_with_turboquant, but that symbol isn't implemented/exported anywhere in the repo. This makes the usage example incorrect and will confuse users. Either add the helper function and export it from extension/llm/modules/turboquant/__init__.py, or update the docstring to reflect the actual integration path (e.g., _apply_turboquant in the Qwen export script).
Suggested change:
```diff
-from executorch.extension.llm.modules.turboquant import (
-    TurboQuantKVCache,
-    replace_kv_cache_with_turboquant,
-)
-
-# After model construction, before torch.export:
-replace_kv_cache_with_turboquant(model, kv_cache_class=KVCache)
+from executorch.extension.llm.modules.turboquant.kv_cache import (
+    TurboQuantKVCache,
+)
+
+# Instantiate and wire this module into the model wherever a KV cache
+# implementation is selected before torch.export.
+kv_cache = TurboQuantKVCache(
+    n_heads=n_heads,
+    head_dim=head_dim,
+    max_seq_len=max_seq_len,
+)
```
```python
    """
    B, H_Q, N_Q, D = query.shape
    _, H_KV, N_KV, HALF_D = k_packed.shape
```
tq4_sdpa assumes GQA-style head mapping with H_Q // H_KV and will break (division by zero or incorrect mapping) when H_Q < H_KV or H_Q % H_KV != 0. Please validate these constraints in the Python wrapper (and ideally assert k_packed.size(-1) * 2 == D as well) to fail fast with a clear error.
Suggested change:
```python
if H_KV <= 0:
    raise ValueError(
        f"tq4_sdpa requires k_packed to have a positive KV head dimension, got H_KV={H_KV}."
    )
if H_Q < H_KV:
    raise ValueError(
        f"tq4_sdpa requires GQA-compatible head mapping with H_Q >= H_KV, got H_Q={H_Q}, H_KV={H_KV}."
    )
if H_Q % H_KV != 0:
    raise ValueError(
        f"tq4_sdpa requires GQA-compatible head mapping with H_Q % H_KV == 0, got H_Q={H_Q}, H_KV={H_KV}."
    )
if k_packed.size(-1) * 2 != D:
    raise ValueError(
        f"tq4_sdpa requires k_packed.size(-1) * 2 == query.size(-1), got "
        f"k_packed.size(-1)={k_packed.size(-1)} and D={D}."
    )
if v_packed.size(-1) != HALF_D:
    raise ValueError(
        f"tq4_sdpa requires v_packed.size(-1) == k_packed.size(-1), got "
        f"v_packed.size(-1)={v_packed.size(-1)} and k_packed.size(-1)={HALF_D}."
    )
```
```python
# Masking: use explicit mask when provided, never auto-causal.
# The caller (FullAttention) always provides an explicit bool mask
# that handles both prefill (lower-triangular) and decode (row mask).
HAS_MASK = attn_mask is not None
is_causal = False
if HAS_MASK:
    Mask_ptr = attn_mask
    stride_mb = attn_mask.stride(0)
    stride_mq = attn_mask.stride(2)
    stride_mk = attn_mask.stride(3)
else:
    Mask_ptr = k_packed  # dummy, won't be accessed
    stride_mb = 0
    stride_mq = 0
    stride_mk = 0
```
The Triton kernel ignores the mask's head dimension (it only offsets by batch), but tq4_sdpa accepts any attn_mask. If a caller passes a per-head mask (shape [B, H_Q, L_Q, L_KV]), the kernel will effectively read only the first head's mask. Consider validating attn_mask.size(1) == 1 (or extending the kernel to handle per-head masks) to avoid silent misbehavior.
```python
@classmethod
def setUpClass(cls):
    _skip_if_no_cuda()
    cls.tq4_sdpa = _import_tq4_sdpa()


def _run_test(self, B, H_q, H_kv, Lq, Lk, D, attn_mask=None, min_cosine=0.95):
    torch.manual_seed(42)
    centroids, boundaries, rotation = _make_codebook_and_rotation(D)
    centroids = centroids.cuda()
    boundaries = boundaries.cuda()
    rotation = rotation.cuda()
```
These CUDA kernel tests rely on turboquant-vllm for reference codebook/rotation generation. If the dependency isn't installed, the tests will error rather than skip. Please add an ImportError guard (e.g., in setUpClass) that skips the suite when turboquant_vllm is unavailable, consistent with other optional-dependency tests in the repo.
Force-pushed 969767b to 7fb707a
@claude Look at the revision based on previous feedback. Also review the PR more holistically
Claude finished @mergennachin's task in 8m 19s
Review of PR #18687 — TurboQuant TQ4 KV Cache Compression (revision check + holistic review)
Status of Previous Review Feedback
There is only one commit (
Holistic Review
Taking a step back from individual line items, this is my assessment of the PR as a complete contribution.
Architecture & Design
The PR adds TQ4 KV cache compression as a composable, opt-in feature. The design has three well-separated layers:
This layering is good. Each component is testable in isolation, and the integration surface in the model is small (one
Kernel Correctness
The Triton kernel (
This correctly inverts the packing
The online softmax follows the standard two-pass-free algorithm with
One subtlety: the
Test Coverage
Coverage is comprehensive and well-structured:
All test files have proper skip guards for the optional
Remaining Issues
1. Unconditional
2. Mask head-dimension not validated — The kernel accesses the mask with
3. Input dtype/device validation — The existing
4. Pre/post rotation cost — The
What Looks Good
Summary
This is a well-structured PR. The kernel is correct, test coverage is thorough, and the integration is clean. The main actionable item is the unconditional
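For reference on the online softmax the review alludes to, here is a toy NumPy sketch of the streaming (FlashAttention-style) accumulation that processes K/V in tiles while maintaining a running max and denominator; names and tiling are illustrative, not the kernel's code:

```python
import numpy as np

def online_softmax_attention(q, k, v, tile=4):
    """Streaming softmax: visit K/V tile by tile, keeping a running row max,
    denominator, and weighted-V accumulator, so no full score row is stored."""
    m = -np.inf                      # running max of scores seen so far
    l = 0.0                          # running softmax denominator
    acc = np.zeros_like(v[0])        # running weighted-V accumulator
    for start in range(0, k.shape[0], tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        s = kt @ q                   # scores for this tile
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)    # rescale previous partial results
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ vt
        m = m_new
    return acc / l

# Matches a direct (materialized) softmax attention for a single query row.
rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal((16, 8))
v = rng.standard_normal((16, 8))
w = np.exp(k @ q - (k @ q).max())
ref = (w / w.sum()) @ v
assert np.allclose(online_softmax_attention(q, k, v), ref)
```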
Force-pushed 7fb707a to 2338313
Pull request overview
Copilot reviewed 14 out of 14 changed files in this pull request and generated 5 comments.
.ci/scripts/export_model_artifact.sh
Outdated
```bash
# turboquant-vllm requires Python >=3.12; upgrade if needed
PYTHON_MINOR=$(python -c "import sys; print(sys.version_info.minor)")
if [ "$PYTHON_MINOR" -lt 12 ]; then
  echo "Upgrading Python to 3.12 (turboquant-vllm requires >=3.12)"
  conda install -y python=3.12
fi
pip install safetensors huggingface_hub
pip install turboquant-vllm==1.4.0
pip install -r examples/models/qwen3_5_moe/requirements.txt
```
Upgrading Python in-place (conda install -y python=3.12) inside the artifact export script can break the already-installed executorch package and any compiled extensions in the environment (new interpreter / site-packages), which would make the subsequent python -m executorch... invocations fail. Consider ensuring the job/environment starts on Python 3.12 (or creating/activating a dedicated conda env for Qwen3.5 MoE export) rather than mutating the interpreter mid-script.
Force-pushed 2338313 to aa9190d
Force-pushed aa9190d to c17c366
Pull request overview
Copilot reviewed 14 out of 14 changed files in this pull request and generated 4 comments.
```python
B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

if H_Q % H_KV != 0:
    raise RuntimeError(
        f"H_Q must be a multiple of H_KV for GQA head mapping, "
        f"got H_Q={H_Q}, H_KV={H_KV}"
    )
if HALF_D * 2 != D:
    raise RuntimeError(
        f"k_packed last dim ({HALF_D}) * 2 must equal query head_dim ({D})"
    )

sm_scale = 1.0 / math.sqrt(D)

# Reshape norms: [B, H, S, 1] -> [B, H, S]
k_n = k_norms.reshape(B, H_KV, N_KV).contiguous()
v_n = v_norms.reshape(B, H_KV, N_KV).contiguous()

# Pre-rotate Q: Q_rot = Q @ Pi^T
# O(B·H·L·D²) matmul in float32. For long-sequence prefill this cost
# may be non-trivial; a future optimization could fuse rotation into
# the kernel.
q_rot = torch.matmul(query.float(), rotation.T).to(query.dtype).contiguous()

out_rot = torch.empty_like(query)

HEAD_DIM_PAD = _next_pow2(D)
HALF_D_PAD = _next_pow2(HALF_D)

HAS_MASK = attn_mask is not None
if is_causal and N_Q != N_KV:
    raise RuntimeError(
        f"is_causal requires L_Q == L_KV, got L_Q={N_Q}, L_KV={N_KV}. "
        "For decode (L_Q < L_KV), use an explicit bool mask instead."
    )
if HAS_MASK:
    Mask_ptr = attn_mask
    stride_mb = attn_mask.stride(0)
    stride_mq = attn_mask.stride(2)
    stride_mk = attn_mask.stride(3)
else:
    Mask_ptr = 0  # dummy, never dereferenced when HAS_MASK=False
```
tq4_sdpa does not validate key/value dtypes/shapes beyond head mapping. In particular, the kernel assumes k_packed/v_packed are uint8, k_norms/v_norms are float32 shaped [B,H,L,1] (or reshaped to [B,H,L]), centroids is length-16, rotation is [D,D], and attn_mask is bool [B,1,L_Q,L_KV]. Missing these checks can lead to hard-to-debug Triton crashes or silent misbehavior. Add explicit input validation and helpful error messages up front.
.github/workflows/cuda.yml
Outdated
```bash
# Run TurboQuant KV cache module tests
python -m pytest extension/llm/modules/test/test_turboquant_kv_cache.py -v -o "addopts="

# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
```
This workflow runs TurboQuant-related tests but does not install turboquant-vllm (commented out due to Python 3.10). As a result, these tests will be skipped, so the new feature/kernel isn’t actually exercised in CI. Consider adding a separate job using Python >=3.12 (or a dedicated container) to install turboquant-vllm and run these tests, or gate these test invocations on the dependency being present to avoid giving a false sense of coverage.
Suggested change:
```diff
-# Run TurboQuant KV cache module tests
-python -m pytest extension/llm/modules/test/test_turboquant_kv_cache.py -v -o "addopts="
-
-# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
-python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
+# Run Qwen 3.5 MoE quantize roundtrip tests
+python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py -v -o "addopts="
+
+# Run TurboQuant-dependent tests only when the dependency is available
+if python -c "import importlib.util, sys; sys.exit(0 if importlib.util.find_spec('turboquant_vllm') else 1)"; then
+  python -m pytest extension/llm/modules/test/test_turboquant_kv_cache.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
+else
+  echo "Skipping TurboQuant-specific tests because turboquant-vllm is not installed in this job."
+fi
```
.ci/scripts/export_model_artifact.sh
Outdated
```bash
# turboquant-vllm requires python 3.12 or above, our CI is still on 3.10
# pip install turboquant-vllm==1.4.0
```
turboquant-vllm is not installed here (commented out), which means the Qwen3.5 MoE export artifact path cannot exercise the TurboQuant codepath during CI/artifact generation. If TurboQuant support is expected to be validated in this flow, add a Python >=3.12 environment (or conditional installation) so the dependency can be installed and the export can cover the turboquant flag.
Suggested change:
```diff
-# turboquant-vllm requires python 3.12 or above, our CI is still on 3.10
-# pip install turboquant-vllm==1.4.0
+if python -c 'import sys; raise SystemExit(0 if sys.version_info >= (3, 12) else 1)'; then
+  pip install turboquant-vllm==1.4.0
+else
+  echo "Skipping turboquant-vllm installation: Python 3.12+ is required."
+fi
```
Force-pushed c17c366 to f9ade5f
Force-pushed f9ade5f to 900d589
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 4 comments.
```python
B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

if H_Q % H_KV != 0:
    raise RuntimeError(
        f"H_Q must be a multiple of H_KV for GQA head mapping, "
        f"got H_Q={H_Q}, H_KV={H_KV}"
    )
if HALF_D * 2 != D:
    raise RuntimeError(
        f"k_packed last dim ({HALF_D}) * 2 must equal query head_dim ({D})"
    )
```
tq4_sdpa currently relies on implicit assumptions about tensor dtypes/shapes (e.g., k_packed/v_packed uint8, k_norms/v_norms fp32 with last dim 1, centroids length 16, attn_mask bool with shape [B,1,Lq,Lkv], all CUDA and contiguous). Other Triton ops in this repo validate inputs up front (e.g., sdpa() calls _validate_sdpa_inputs in backends/cuda/triton/kernels/sdpa.py:936+). Adding similar validation here would prevent silent wrong results or hard-to-diagnose Triton failures when called with unsupported inputs.
Force-pushed 900d589 to d949bf9
@claude Review this PR
Claude finished @mergennachin's task in 2m 45s
Review of PR #18687 — TurboQuant TQ4 KV Cache Compression
Well-structured PR. The kernel, cache module, codebook solver, model integration, and tests are all consistent and correct. The previous two reviews covered the landscape well; this review focuses on confirming the current state and flagging remaining items.
Triton Kernel —
| Item | Status |
|---|---|
| Unconditional `turboquant_vllm` import | Not an issue — codebook.py is self-contained; `turboquant_vllm` is not imported anywhere in production code |
| Mask head-dim validation | Still missing — see item 1 above |
| dtype/device validation | Still missing — see item 2 above |
| `bits != 4` validation | Already handled (kv_cache.py:63-66) |
| GQA validation | Already handled (tq4_sdpa.py:606-613) |
| Test skip guards | Already handled in all test files |
| Dummy mask pointer | Fixed — uses `Mask_ptr = 0` (tq4_sdpa.py:656) |
Summary
The PR is correct and well-organized. The codebook module being self-contained (no turboquant-vllm dependency) is a significant improvement over what the previous reviews assumed — the only optional dependency is scipy, imported lazily with a clear error.
Two actionable items:
- Add `attn_mask.shape[1] == 1` validation in `tq4_sdpa` to match `sdpa.py`, preventing silent per-head mask bugs.
- Add dtype/device validation on input tensors to match `sdpa.py`'s `_validate_sdpa_inputs` pattern.
Everything else is clean.
@claude How would you compare against other repos like https://github.com/cksac/turboquant-model, https://github.com/0xSero/turboquant or https://github.com/tonbistudio/turboquant-pytorch
Claude finished @mergennachin's task in 4m 42s
Comparing TurboQuant implementations
I was unable to access those three repositories (sandbox network restrictions blocked external GitHub API calls and web fetches), so I can't confirm whether they exist or what they contain. However, I'd flag caution: none of those three repositories are referenced anywhere in the TurboQuant paper (arXiv 2504.19874), in this PR, or in the known legitimate implementation. The only external implementation referenced by this PR's code is
Name-squatting on popular paper names is a known phenomenon on GitHub. The TurboQuant paper is recent (ICLR 2026), and repositories with generic names like
What this PR actually implements vs. what a legitimate TurboQuant implementation should contain
For reference, this PR's implementation includes the core components described in the paper:
A legitimate third-party implementation should have most of these components and should produce numerically consistent results with the paper's algorithm. Repos that only contain packaging metadata, vague READMEs, or code unrelated to vector quantization of KV caches are unlikely to be genuine implementations.
Force-pushed d949bf9 to 613c0cb
Pull request overview
Copilot reviewed 17 out of 17 changed files in this pull request and generated 8 comments.
```python
def _validate_tq4_inputs(query, k_packed, v_packed):
    """Validate tensor shapes, dtypes, and device for tq4_sdpa."""
    B, H_Q, N_Q, D = query.shape
    B_kp, H_KV, N_KV, HALF_D = k_packed.shape

    if not query.is_cuda:
        raise RuntimeError("query must be a CUDA tensor")
    if query.dtype != torch.bfloat16:
        raise RuntimeError(f"query must be bfloat16, got {query.dtype}")
    if query.dim() != 4:
        raise RuntimeError(f"query must be 4D [B, H, L, D], got {query.dim()}D")
    if k_packed.dim() != 4 or v_packed.dim() != 4:
        raise RuntimeError("k_packed and v_packed must be 4D [B, H, L, D//2]")
```
The function unpacks query.shape and k_packed.shape before verifying dim() == 4. If a caller passes a tensor with the wrong rank, this will raise a ValueError during tuple unpacking instead of the intended RuntimeError. Move the dim() checks (for query, k_packed, v_packed) before any shape unpacking so invalid inputs fail with the expected error message.
```python
_validate_tq4_inputs(query, k_packed, v_packed)

B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

_validate_tq4_mask(attn_mask, B, N_Q, N_KV)

sm_scale = 1.0 / math.sqrt(D)
num_groups = H_Q // H_KV

# Build [256] bf16 lookup tables from [16] centroids.
# In the export path, inductor fuses this into the compiled graph.
all_bytes = torch.arange(256, device=centroids.device)
lut_hi = centroids[(all_bytes >> 4).long()].to(query.dtype).contiguous()
lut_lo = centroids[(all_bytes & 0x0F).long()].to(query.dtype).contiguous()

# Reshape norms: [B, H, S, 1] -> [B, H, S]
k_n = k_norms.reshape(B, H_KV, N_KV).contiguous()
v_n = v_norms.reshape(B, H_KV, N_KV).contiguous()
```
tq4_sdpa assumes centroids, rotation, k_norms, and v_norms are CUDA tensors on the same device as query. Currently only query/k_packed/v_packed/attn_mask are validated. If (for example) centroids is on CPU, all_bytes/LUTs will be created on CPU and the Triton kernel launch will fail. Add explicit validation that centroids.is_cuda, rotation.is_cuda, k_norms.is_cuda, and v_norms.is_cuda (and ideally device matches query.device), plus basic shape checks for centroids ([16]) and rotation ([D, D]) to fail fast with a clear error.
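As a sanity check on the [256]-entry LUT construction discussed here, the following NumPy sketch (illustrative names, not the PR's code) shows that indexing a byte-level table directly by the packed byte matches unpacking nibbles first and then gathering centroids:

```python
import numpy as np

rng = np.random.default_rng(0)
centroids = rng.standard_normal(16)          # 16 codebook entries
all_bytes = np.arange(256)
lut_hi = centroids[all_bytes >> 4]           # centroid for each byte's high nibble
lut_lo = centroids[all_bytes & 0x0F]         # centroid for each byte's low nibble

packed = rng.integers(0, 256, size=32, dtype=np.uint8)

# Reference path: unpack nibbles, then gather centroids per 4-bit index.
ref_hi = centroids[(packed >> 4) & 0x0F]
ref_lo = centroids[packed & 0x0F]

# LUT path: one gather per packed byte, no bit ops in the hot loop.
assert np.array_equal(lut_hi[packed], ref_hi)
assert np.array_equal(lut_lo[packed], ref_lo)
```

The trade is 256 precomputed entries per table against removing shift/mask operations from the attention inner loop.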
```python
# Pre-rotate Q: Q_rot = Q @ Pi^T (bf16 — TQ4 error dominates)
q_rot = torch.matmul(query, rotation.T.to(query.dtype)).contiguous()

out_rot = torch.empty_like(query)

HAS_MASK = attn_mask is not None
```
This does rotation.T.to(query.dtype) on every call, which creates a new converted tensor each time (and similarly rotation.to(query.dtype) at return). Consider converting once per call (e.g., rot = rotation.to(query.dtype); rot_t = rot.T) and reusing those tensors for the pre/post matmuls to reduce overhead in decode loops.
```python
"""KV cache with TQ4 compression.

Stores K/V as nibble-packed uint8 indices (2 indices per byte) plus
fp32 per-vector norms. The ``update()`` method compresses incoming
```
The docstring states norms are stored as fp32, but the implementation registers k_norms/v_norms as torch.bfloat16 and _compress() returns bf16 norms. Please update the docstring to match the actual dtype (bf16), or switch the buffers/returned norms to fp32 if fp32 is required for correctness.
Suggested change:
```diff
-fp32 per-vector norms. The ``update()`` method compresses incoming
+bf16 per-vector norms. The ``update()`` method compresses incoming
```
```python
def test_roundtrip_quality(self):
    cache = TurboQuantKVCache(N_HEADS, HEAD_DIM, MAX_SEQ_LEN, BITS)
    x = torch.randn(1, N_HEADS, 10, HEAD_DIM)
```
TurboQuantKVCache construction calls solve_lloyd_max(), which raises ImportError if SciPy is not installed. If the repo's default test environment doesn't include SciPy, these tests will fail even when TurboQuant is not used elsewhere. Consider skipping these tests when SciPy isn't available (e.g., via an import check / pytest importorskip equivalent for unittest), or providing a non-SciPy fallback path for codebook initialization in test mode.
.ci/scripts/test_model_e2e.sh
Outdated
```diff
 RUNNER_TARGET="qwen3_5_moe_runner"
 RUNNER_PATH="qwen3_5_moe"
-EXPECTED_OUTPUT="Paris"
+EXPECTED_OUTPUT="the"
```
Changing EXPECTED_OUTPUT from a specific token ("Paris") to a very common token ("the") significantly weakens the e2e correctness check and increases the chance of false positives. Also, the longer prompt and --max_new_tokens 128 likely increase CI runtime and flakiness. Consider keeping a correctness-oriented prompt/expected output for CI (short + specific), and gating the longer perf-style prompt behind an opt-in flag/environment variable.
Suggested change:
```diff
-EXPECTED_OUTPUT="the"
+EXPECTED_OUTPUT="Paris"
```
.ci/scripts/test_model_e2e.sh
Outdated
```bash
# pp512/tg128 matches llama-bench defaults for comparable perf numbers
PROMPT_512=$(python3 -c "print('The history of artificial intelligence begins in antiquity, with myths, stories and rumors of artificial beings endowed with intelligence or consciousness by master craftsmen. The seeds of modern AI were planted by philosophers who attempted to describe the process of human thinking as the mechanical manipulation of symbols. This work culminated in the invention of the programmable digital computer, a machine based on the abstract essence of mathematical reasoning. ' * 6)")
RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt '${PROMPT_512}' --max_new_tokens 128 --temperature 0"
```
Suggested change:
```diff
-# pp512/tg128 matches llama-bench defaults for comparable perf numbers
-PROMPT_512=$(python3 -c "print('The history of artificial intelligence begins in antiquity, with myths, stories and rumors of artificial beings endowed with intelligence or consciousness by master craftsmen. The seeds of modern AI were planted by philosophers who attempted to describe the process of human thinking as the mechanical manipulation of symbols. This work culminated in the invention of the programmable digital computer, a machine based on the abstract essence of mathematical reasoning. ' * 6)")
-RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt '${PROMPT_512}' --max_new_tokens 128 --temperature 0"
+if [ "${QWEN3_5_MOE_PERF_PROMPT:-0}" = "1" ]; then
+  # Opt-in perf-style prompt for local benchmarking/stress runs.
+  # pp512/tg128 matches llama-bench defaults for comparable perf numbers.
+  PROMPT_512=$(python3 -c "print('The history of artificial intelligence begins in antiquity, with myths, stories and rumors of artificial beings endowed with intelligence or consciousness by master craftsmen. The seeds of modern AI were planted by philosophers who attempted to describe the process of human thinking as the mechanical manipulation of symbols. This work culminated in the invention of the programmable digital computer, a machine based on the abstract essence of mathematical reasoning. ' * 6)")
+  RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt '${PROMPT_512}' --max_new_tokens 128 --temperature 0"
+else
+  # Keep CI e2e focused on correctness: short prompt, specific expected token,
+  # and small generation length to reduce runtime and flakiness.
+  EXPECTED_OUTPUT="Paris"
+  RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --prompt 'The capital of France is' --max_new_tokens 16 --temperature 0"
+fi
```
.ci/scripts/export_model_artifact.sh
Outdated
```diff
 TORCHINDUCTOR_CACHE_DIR="$INDUCTOR_CACHE" \
 python -m executorch.examples.models.qwen3_5_moe.export \
     --prequantized "$LOCAL_MODEL_DIR" \
-    --output-dir "${OUTPUT_DIR}"
+    --output-dir "${OUTPUT_DIR}" \
+    --turboquant
```
The PR introduces --turboquant as an optional flag, but the artifact export script now always enables it for Qwen3.5 MoE. If the intention is to keep baseline artifacts unchanged by default, consider producing a separate TurboQuant artifact (or gating this behind an env var) to avoid silently changing the default exported model behavior/perf characteristics.
Force-pushed 613c0cb to f325eee
Add TurboQuant (arXiv 2504.19874) KV cache compression to the CUDA backend, reducing KV cache memory 3.8x by storing nibble-packed uint8 indices + bf16 norms instead of bf16 tensors. A fused Triton SDPA kernel decompresses K/V per-tile in the attention inner loop so the full cache is never materialized.

Components:
- backends/cuda/triton/kernels/tq4_sdpa.py: Fused TQ4 Flash Attention kernel with Pack GQA optimization (adapted from sdpa.py structure), precomputed [256]-entry bf16 LUT for zero-bit-op centroid gather, and norm factoring (norms multiplied on [M,N] QK/P matrices instead of [N,D] K/V tiles). NaN-safe softmax guards for sparse masks. Registered as @triton_op for torch.export + CUDA backend lowering.
- extension/llm/modules/turboquant/: TurboQuantKVCache nn.Module with bf16 compression path and self-contained Lloyd-Max codebook solver (no external dependencies; scipy used lazily for codebook init only).
- examples/models/qwen3_5_moe/: --turboquant flag in export.py, branch in FullAttention.forward() between triton::sdpa and tq4_sdpa. Both paths use filled-length KV cache slicing (PAD(filled, 256)) to skip empty positions, enabling 200K context without O(n^2) mask or linear decode slowdown. O(n) causal mask computed on the fly from cache_positions, replacing the O(n^2) pre-materialized mask.
- backends/cuda/triton/kernels/sdpa.py: guard_or_false for L_kv_k != L_kv_v check, enabling export with data-dependent KV cache slicing.
- backends/aoti/: Added aoti_torch_dtype_uint8 shim and Byte ScalarType to slim headers (required for uint8 KV cache tensors in C++ runtime).

Performance (A100, Qwen 3.5 MoE, B=1, GQA 16:2, D=256, seq=4096):
- TQ4 SDPA kernel: 0.66ms (was 3.74ms before optimizations)
- Baseline bf16: 0.45ms (1.5x overhead for 3.8x memory savings)
- Full AOTI path: 0.79ms (inductor fuses compress ops)
- E2e decode: 59.8 tok/s (vs 70.6 baseline, 20.6 original)
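The norm-factoring optimization mentioned in the commit message rests on a simple algebraic identity: scaling each K row by its norm before the QK matmul is equivalent to scaling the corresponding score columns afterward, which touches the smaller [M,N] matrix instead of the [N,D] tile. A NumPy sketch of the identity (illustrative, not the kernel's code):

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, D = 4, 6, 8
Q = rng.standard_normal((M, D))
K = rng.standard_normal((N, D))   # stands in for de-quantized unit-norm directions
n = rng.standard_normal(N)        # per-vector norms

# Naive: scale each K row by its norm before the matmul (touches N*D values).
scores_naive = Q @ (n[:, None] * K).T

# Factored: matmul first, then scale score columns (touches M*N values).
scores_factored = (Q @ K.T) * n[None, :]

assert np.allclose(scores_naive, scores_factored)
```

The same factoring applies on the V side, where the softmax probabilities P (again [M,N]) absorb the V norms before the P·V matmul.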
Force-pushed f325eee to 94973a3
Pull request overview
Copilot reviewed 16 out of 16 changed files in this pull request and generated 8 comments.
```python
def _validate_tq4_mask(attn_mask, B, N_Q, N_KV):
    """Validate attention mask for tq4_sdpa."""
    if attn_mask is None:
        return
    if attn_mask.dtype != torch.bool:
        raise RuntimeError(
            f"attn_mask must be bool, got {attn_mask.dtype}. "
            "Additive float masks are not supported."
        )
    if not attn_mask.is_cuda:
        raise RuntimeError("attn_mask must be a CUDA tensor")
    if attn_mask.shape[1] != 1:
        raise RuntimeError(
            f"attn_mask head dimension must be 1 (broadcast over heads); "
            f"per-head masks are not supported. "
            f"Got attn_mask.shape={attn_mask.shape}"
        )
```
_validate_tq4_mask() indexes attn_mask.shape[1], [2], and [3] without first validating attn_mask.dim() == 4. For masks with the wrong rank, this will raise an IndexError instead of the intended RuntimeError. Add an explicit rank check early (e.g., require 4D) before accessing shape indices.
```python
_validate_tq4_inputs(query, k_packed, v_packed)

B, H_Q, N_Q, D = query.shape
_, H_KV, N_KV, HALF_D = k_packed.shape

_validate_tq4_mask(attn_mask, B, N_Q, N_KV)

sm_scale = 1.0 / math.sqrt(D)
num_groups = H_Q // H_KV

# Build [256] bf16 lookup tables from [16] centroids.
# In the export path, inductor fuses this into the compiled graph.
all_bytes = torch.arange(256, device=centroids.device)
lut_hi = centroids[(all_bytes >> 4).long()].to(query.dtype).contiguous()
lut_lo = centroids[(all_bytes & 0x0F).long()].to(query.dtype).contiguous()

# Reshape norms: [B, H, S, 1] -> [B, H, S]
k_n = k_norms.reshape(B, H_KV, N_KV).contiguous()
v_n = v_norms.reshape(B, H_KV, N_KV).contiguous()

# Pre-rotate Q: Q_rot = Q @ Pi^T (bf16 — TQ4 error dominates)
q_rot = torch.matmul(query, rotation.T.to(query.dtype)).contiguous()
```
tq4_sdpa() implicitly assumes centroids, rotation, k_norms, and v_norms are CUDA tensors on the same device as query (e.g., it builds LUTs on centroids.device and matmuls with rotation). If any of these are on CPU or a different CUDA device, the function will fail with less clear downstream errors. Consider extending validation to require these tensors are CUDA, on query.device, and have expected dtypes/shapes (e.g., centroids length 16, rotation (D,D), norms shape [B,H,L] or [B,H,L,1]).
```python
try:
    from scipy import integrate
    from scipy.stats import norm
except ImportError:
    raise ImportError(
        "scipy is required for TurboQuant codebook computation. "
        "Install it with: pip install scipy"
    )
```
solve_lloyd_max() hard-requires SciPy (integrate + stats.norm) and raises ImportError if it's missing, but this repository's Python requirements files don't list SciPy. As-is, instantiating TurboQuantKVCache (and tests that depend on it) will fail in environments without SciPy. Consider providing a SciPy-free fallback (e.g., precomputed centroids/boundaries for common head_dims) and/or skipping TurboQuant tests when SciPy isn't available, plus documenting/adding the optional dependency where appropriate.
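One possible SciPy-free fallback, sketched here as an empirical Lloyd-Max fit on Monte Carlo samples instead of the integral-based update that needs scipy.integrate / scipy.stats (the function name and parameters are illustrative, not the PR's API):

```python
import random


def lloyd_max_gaussian(n_levels=16, n_samples=50_000, iters=30, seed=0):
    """Empirical Lloyd-Max: fit an optimal scalar quantizer to N(0, 1)
    by alternating boundary/centroid updates on Monte Carlo samples."""
    rng = random.Random(seed)
    samples = sorted(rng.gauss(0.0, 1.0) for _ in range(n_samples))
    # Initialize centroids at evenly spaced sample quantiles.
    centroids = [samples[(2 * i + 1) * n_samples // (2 * n_levels)]
                 for i in range(n_levels)]
    for _ in range(iters):
        # Nearest-neighbor condition: boundaries are centroid midpoints.
        bounds = [(a + b) / 2 for a, b in zip(centroids, centroids[1:])]
        sums, counts = [0.0] * n_levels, [0] * n_levels
        cell = 0
        for x in samples:  # samples are sorted, so one forward scan suffices
            while cell < n_levels - 1 and x > bounds[cell]:
                cell += 1
            sums[cell] += x
            counts[cell] += 1
        # Centroid condition: each level moves to its cell's conditional mean.
        centroids = [s / c if c else m
                     for s, c, m in zip(sums, counts, centroids)]
    return centroids


centroids = lloyd_max_gaussian()
```

Precomputing the 16-entry result for the common head dims (and shipping it as a constant) would remove the runtime dependency entirely.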
```diff
 # Recompute cache_positions for full attention layers
 for layer in model.layers:
-    if hasattr(layer.attn, "mask"):
-        mask = torch.tril(
-            torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool)
+    if hasattr(layer.attn, "cache_positions"):
+        layer.attn.cache_positions = torch.arange(
+            config.max_seq_len, dtype=torch.long
+        )
```
_materialize_buffers()'s docstring still says it "recomputes RoPE inv_freq and causal masks", but the implementation now recomputes cache_positions (and no longer rebuilds a mask buffer). Please update the docstring (above) to match the new behavior so readers don't expect mask to exist.
```diff
 self.kv_cache = KVCache(self.n_kv_heads, self.head_dim, config.max_seq_len)
 self.turboquant = False

-mask = torch.tril(
-    torch.ones(config.max_seq_len, config.max_seq_len, dtype=torch.bool)
+self.register_buffer(
+    "cache_positions",
+    torch.arange(config.max_seq_len, dtype=torch.long),
+)
```
FullAttention no longer registers a mask buffer, but elsewhere in this file (e.g., from_hf_checkpoint()'s runtime_prefixes) still treats .mask as an expected missing runtime key. With this change, .cache_positions is likely to be reported as an unexpected missing key (warning), potentially hiding real mismatches. Update the missing-key allowlist to include .cache_positions instead of .mask.
```python
_, n_heads, max_seq_len, head_dim = old_cache.k_cache.shape
layer.attn.kv_cache = TurboQuantKVCache(
    n_heads,
```
In _apply_turboquant(), _, n_heads, max_seq_len, head_dim = old_cache.k_cache.shape uses n_heads to mean the number of KV heads (the second dimension of k_cache). Renaming this to something like n_kv_heads would make the intent clearer and avoid confusion with the model's query head count.
```diff
-_, n_heads, max_seq_len, head_dim = old_cache.k_cache.shape
-layer.attn.kv_cache = TurboQuantKVCache(
-    n_heads,
+_, n_kv_heads, max_seq_len, head_dim = old_cache.k_cache.shape
+layer.attn.kv_cache = TurboQuantKVCache(
+    n_kv_heads,
```
```python
centroids, boundaries = solve_lloyd_max(head_dim, bits)
rotation = generate_rotation_matrix(head_dim, seed=seed)

self.register_buffer("centroids", centroids)
self.register_buffer("boundaries", boundaries.to(torch.bfloat16))
self.register_buffer("rotation", rotation)
self.register_buffer("rotation_T", rotation.T.to(torch.bfloat16).contiguous())
```
TurboQuantKVCache.`__init__` always calls generate_rotation_matrix(head_dim, seed), and _apply_turboquant() uses the default seed for every layer. That means Qwen 3.5 MoE will recompute the same (D×D) QR-based rotation matrix many times during export/init. Consider caching generate_rotation_matrix() (similar to solve_lloyd_max) or allowing callers to pass a precomputed rotation to avoid redundant O(D^3) work.
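A sketch of the caching idea with functools.lru_cache. The orthonormalization here is a toy pure-Python Gram-Schmidt stand-in for the PR's QR-based generator, and the function name/signature are assumptions:

```python
import functools
import random


@functools.lru_cache(maxsize=None)
def generate_rotation_matrix_cached(dim, seed=42):
    """Toy orthonormal rotation via Gram-Schmidt on random Gaussian rows.

    lru_cache makes repeated per-layer calls with the same (dim, seed)
    free after the first construction. Returns tuples so the cached
    value is immutable and safe to share across layers.
    """
    rng = random.Random(seed)
    rows = []
    while len(rows) < dim:
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        for q in rows:  # subtract projections onto previous rows
            dot = sum(a * b for a, b in zip(v, q))
            v = [a - dot * b for a, b in zip(v, q)]
        norm = sum(a * a for a in v) ** 0.5
        if norm > 1e-8:  # skip (vanishingly rare) degenerate draws
            rows.append(tuple(a / norm for a in v))
    return tuple(rows)


# Second call with identical args is a cache hit: same object, no redo.
r1 = generate_rotation_matrix_cached(8)
r2 = generate_rotation_matrix_cached(8)
assert r1 is r2
```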
```python
if not query.is_cuda:
    raise RuntimeError("query must be a CUDA tensor")
if query.dtype != torch.bfloat16:
    raise RuntimeError(f"query must be bfloat16, got {query.dtype}")
if query.dim() != 4:
    raise RuntimeError(f"query must be 4D [B, H, L, D], got {query.dim()}D")
if k_packed.dim() != 4 or v_packed.dim() != 4:
    raise RuntimeError("k_packed and v_packed must be 4D [B, H, L, D//2]")
if k_packed.dtype != torch.uint8 or v_packed.dtype != torch.uint8:
    raise RuntimeError("k_packed and v_packed must be uint8")
if B_kp != B:
    raise RuntimeError(
        f"Batch dim mismatch: query has B={B}, k_packed has B={B_kp}"
    )
if H_Q % H_KV != 0:
    raise RuntimeError(
        f"H_Q must be a multiple of H_KV for GQA head mapping, "
        f"got H_Q={H_Q}, H_KV={H_KV}"
    )
if HALF_D * 2 != D:
    raise RuntimeError(
```
_validate_tq4_inputs() doesn't verify that k_packed/v_packed are CUDA tensors (only query is checked) and doesn't check that v_packed matches k_packed's shape (B/H/L/D//2). Passing CPU packed tensors or mismatched K/V shapes will lead to confusing kernel-launch or out-of-bounds errors. Add explicit device + shape compatibility checks for K/V (and ideally norms) here.
@claude Review this PR |
Claude finished @mergennachin's task in 4m 0s — View job

**Review of PR #18687 — TurboQuant TQ4 KV Cache Compression**

No new commits since the last review (still a single commit).

**Triton Kernel**
| Item | Status |
|---|---|
| Mask head-dim shape[1] == 1 validation | Fixed — _validate_tq4_mask line 614 |
| dtype/device validation on Q/K/V | Fixed — _validate_tq4_inputs lines 573-576 |
| GQA divisibility validation | Fixed — _validate_tq4_inputs lines 587-591 |
| bits != 4 validation | Fixed — kv_cache.py:63-66 |
| Dummy mask pointer | Fixed — uses mask_ptr = 0 at line 704 |
| Mask dim check | Fixed — _validate_tq4_mask lines 620-629 checks B/Lq/Lkv |
| turboquant_vllm hard dependency | Not an issue — codebook is self-contained |
| Test skip guards | All handled |
**Remaining Actionable Items**

1. runtime_prefixes stale .mask entry (model.py:597): should include .cache_positions instead of .mask, since FullAttention now registers cache_positions instead of mask.
2. _materialize_buffers docstring (export.py:315): says "causal masks" but should say "cache_positions".
3. Shape unpacking before rank check (tq4_sdpa.py:570-580): minor — the rank check should come before tuple unpacking so the intended error message is raised.
Items 1-2 are documentation/naming fixes. Item 3 is a minor validation ordering issue. None are blocking.
Summary
The PR is in good shape. Previous review feedback has been addressed — input validation now matches sdpa.py, the codebook is self-contained (no turboquant-vllm dependency), and test coverage is comprehensive. The kernel is correct, the integration is clean, and the architecture (three-layer separation: cache module / Triton kernel / model integration) is well-structured.
The remaining items are minor documentation fixes and a cosmetic validation ordering issue.
TurboQuant TQ4 KV cache compression for Qwen 3.5 MoE
Add TurboQuant (arXiv 2504.19874) KV cache compression to the CUDA
backend, reducing KV cache memory 3.8x by storing nibble-packed uint8
indices + bf16 norms instead of bf16 tensors. A fused Triton SDPA
kernel decompresses K/V per-tile in the attention inner loop so the
full cache is never materialized.
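The nibble packing behind the index compression can be sketched in plain Python. Which element lands in the high vs. low nibble is an assumption here for illustration, not necessarily the kernel's actual layout:

```python
def pack_nibbles(indices):
    """Pack pairs of 4-bit indices (0..15) into bytes, even index -> high nibble."""
    assert len(indices) % 2 == 0
    return bytes((indices[i] << 4) | indices[i + 1]
                 for i in range(0, len(indices), 2))


def unpack_nibbles(packed):
    """Recover the original 4-bit indices from the packed byte stream."""
    out = []
    for b in packed:
        out.append(b >> 4)
        out.append(b & 0x0F)
    return out


idx = [0, 15, 7, 8, 3, 12]
packed = pack_nibbles(idx)
assert len(packed) == len(idx) // 2   # half the bytes of one-index-per-byte
assert unpack_nibbles(packed) == idx  # lossless round-trip
```

Two 4-bit indices per byte cuts index storage to a quarter of bf16's 16 bits per element; the per-vector norms are the remaining overhead.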
Components:

- backends/cuda/triton/kernels/tq4_sdpa.py: Fused TQ4 Flash Attention kernel with Pack GQA optimization (adapted from sdpa.py structure), precomputed [256]-entry bf16 LUT for zero-bit-op centroid gather, and norm factoring (norms multiplied on [M,N] QK/P matrices instead of [N,D] K/V tiles). NaN-safe softmax guards for sparse masks. Registered as @triton_op for torch.export + CUDA backend lowering.
- extension/llm/modules/turboquant/: TurboQuantKVCache nn.Module with bf16 compression path and self-contained Lloyd-Max codebook solver (no external dependencies; scipy used lazily for codebook init only).
- examples/models/qwen3_5_moe/: --turboquant flag in export.py, branch in FullAttention.forward() between standard SDPA and tq4_sdpa.
- backends/aoti/: Added aoti_torch_dtype_uint8 shim and Byte ScalarType to slim headers (required for uint8 KV cache tensors in C++ runtime).
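The [256]-entry LUT trick from the kernel component can be illustrated in plain Python: indexing a table by the whole packed byte replaces per-element shift/mask ops at decode time (the centroid values below are hypothetical placeholders, not the Lloyd-Max codebook):

```python
# 16 hypothetical centroid values; the real ones come from the Lloyd-Max solver.
centroids = [(i - 7.5) / 4.0 for i in range(16)]

# Two 256-entry tables, one per nibble position, built once up front.
lut_hi = [centroids[b >> 4] for b in range(256)]
lut_lo = [centroids[b & 0x0F] for b in range(256)]


def decode_byte(b):
    """Decode a packed byte into its two centroid values with pure gathers
    (no shifts or masks in the decode path)."""
    return lut_hi[b], lut_lo[b]


packed = (3 << 4) | 12  # high nibble = index 3, low nibble = index 12
assert decode_byte(packed) == (centroids[3], centroids[12])
```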
Performance (A100, Qwen 3.5 MoE, B=1, GQA 16:2, D=256, seq=4096):

- TQ4 SDPA kernel: 0.66ms (was 3.74ms before optimizations)
- Baseline bf16: 0.45ms (1.5x overhead for 3.8x memory savings)
- Full AOTI path: 0.79ms (inductor fuses compress ops)
The full Qwen 3.5 MoE model with TurboQuant KV cache compression runs at 75% of baseline decode speed (78 -> 60 tok/s) with 3.8x memory savings relative to the normal KV cache, validated end-to-end through the C++ runner. Note that the full-attention KV cache is only a small part of Qwen 3.5 MoE's memory footprint, since 3/4 of its layers use recurrent states. At 200K context length, TurboQuant saves about 3GB of KV cache (4GB -> 1GB).
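The 3.8x figure is consistent with the per-vector storage layout. A back-of-envelope sketch, assuming one fp32 norm per 256-dim head vector (one part of the PR text mentions bf16 norms instead, which would give closer to 3.9x):

```python
D = 256                   # head dim for Qwen 3.5 MoE full attention
bf16_bytes = D * 2        # baseline: 2 bytes per element
tq4_bytes = D // 2 + 4    # nibble-packed 4-bit indices + one fp32 norm
ratio = bf16_bytes / tq4_bytes
print(round(ratio, 2))    # -> 3.88, i.e. the reported ~3.8x
```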