[NVBug: 6000530] Fix AWQ crash for uncalibrated MoE experts#1142
Conversation
…search phase

When `moe_calib_experts_ratio < 1.0`, some MoE experts may never receive tokens during the AWQ cache phase, leaving `act_scale` as a Python float (`0.0`) instead of a tensor. During the search phase, these uncalibrated experts crash in `get_scale()` on `float.pow()`.

Fix by disabling AWQ for experts with `num_cache_steps == 0` before the search phase begins, so they gracefully fall back to max calibration.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
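The crash described above can be reproduced without torch. Below is a minimal pure-Python sketch; `FakeTensor` and `get_scale_sketch` are illustrative stand-ins, not ModelOpt's actual classes. It shows why a float `act_scale` breaks the search phase:

```python
# Hypothetical stand-ins for the PR's scenario: act_scale starts as the Python
# float 0.0 and is only promoted to a tensor when an expert sees tokens during
# the cache phase.
class FakeTensor:
    """Stand-in for torch.Tensor, exposing just the .pow() that get_scale() calls."""
    def __init__(self, v):
        self.v = v

    def pow(self, p):
        return FakeTensor(self.v ** p)

def get_scale_sketch(act_scale, alpha=0.5):
    # Mirrors the crashing call in the search phase: act_scale.pow(alpha)
    return act_scale.pow(alpha)

calibrated = FakeTensor(4.0)  # expert that received tokens: act_scale is a tensor
uncalibrated = 0.0            # expert that received no tokens: act_scale stays a float

assert get_scale_sketch(calibrated).v == 2.0
try:
    get_scale_sketch(uncalibrated)
except AttributeError as e:
    print(e)  # 'float' object has no attribute 'pow'
```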
📝 Walkthrough: Pre-search now skips AWQ parameter search for quantized linear modules with `num_cache_steps == 0`.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed
Codecov Report
❌ Patch coverage is …

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #1142      +/-   ##
==========================================
+ Coverage   70.19%   70.20%   +0.01%
==========================================
  Files         230      230
  Lines       26073    26080       +7
==========================================
+ Hits        18302    18310       +8
+ Misses       7771     7770       -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 1182-1200: The loop that handles uncalibrated experts leaves input
quantization disabled because setup() may have turned off module.input_quantizer
but postprocess is skipped when module.awq_lite.num_cache_steps == 0; modify the
block handling those modules (the for loop iterating model.named_modules(), the
branch checking is_quantized_linear(module) && hasattr(module, "awq_lite") &&
module.awq_lite.num_cache_steps == 0) to re-enable the input_quantizer state it
originally had: after setting module.input_quantizer.pre_quant_scale and before
disabling module.awq_lite.is_enabled, restore
module.input_quantizer._enable_pre_quant_scale (or call the appropriate
re-enable API on input_quantizer) to the value it had prior to setup() so
uncalibrated experts that started with input quantization enabled end up
re-enabled.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6b32c7d9-8fae-4245-9dc9-bf46dfda9d9f
📒 Files selected for processing (1)
modelopt/torch/quantization/model_calib.py
```python
# Handle uncalibrated experts (e.g. when moe_calib_experts_ratio < 1.0,
# some experts may never receive tokens during the cache phase, leaving act_scale
# as a Python float instead of a tensor, which would crash in get_scale()).
# We fully handle them here: max calibrate weights, apply a neutral (all-ones)
# pre_quant_scale for export consistency, and disable AWQ search.
for name, module in model.named_modules():
    if (
        is_quantized_linear(module)
        and hasattr(module, "awq_lite")
        and module.awq_lite.num_cache_steps == 0
    ):
        with enable_weight_access_and_writeback(module, model, name_to_module):
            max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
        ones_scale = torch.ones(
            module.weight.shape[1], dtype=module.weight.dtype, device=module.weight.device
        )
        module.input_quantizer._enable_pre_quant_scale = True
        module.input_quantizer.pre_quant_scale = ones_scale
        module.awq_lite.is_enabled = False
```
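The all-ones `pre_quant_scale` registered above is a multiplicative no-op on the activation, which is what makes it a safe neutral value for experts that were never calibrated. A pure-Python sketch (the helper name is illustrative, not ModelOpt's API):

```python
# pre_quant_scale multiplies the activation elementwise before quantization,
# so a vector of ones leaves an uncalibrated expert's inputs unchanged while
# still giving export a real tensor to stack alongside calibrated experts.
def apply_pre_quant_scale(x, scale):
    return [xi * si for xi, si in zip(x, scale)]

x = [0.5, -1.25, 3.0]
ones_scale = [1.0] * len(x)
assert apply_pre_quant_scale(x, ones_scale) == x  # identity: nothing is rescaled
```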
Missing input_quantizer re-enable for uncalibrated experts.
When setup() runs, it disables input_quantizer if it was originally enabled. For modules with num_cache_steps == 0, postprocess is skipped (lines 1234-1236), so the input_quantizer is never re-enabled. This will leave input quantization disabled for uncalibrated experts that originally had it enabled.
🐛 Proposed fix to re-enable input_quantizer

```diff
  module.input_quantizer._enable_pre_quant_scale = True
  module.input_quantizer.pre_quant_scale = ones_scale
+ if module.awq_lite.is_input_quantized:
+     module.input_quantizer.enable()
  module.awq_lite.is_enabled = False
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/quantization/model_calib.py` around lines 1182 - 1200, The
loop that handles uncalibrated experts leaves input quantization disabled
because setup() may have turned off module.input_quantizer but postprocess is
skipped when module.awq_lite.num_cache_steps == 0; modify the block handling
those modules (the for loop iterating model.named_modules(), the branch checking
is_quantized_linear(module) && hasattr(module, "awq_lite") &&
module.awq_lite.num_cache_steps == 0) to re-enable the input_quantizer state it
originally had: after setting module.input_quantizer.pre_quant_scale and before
disabling module.awq_lite.is_enabled, restore
module.input_quantizer._enable_pre_quant_scale (or call the appropriate
re-enable API on input_quantizer) to the value it had prior to setup() so
uncalibrated experts that started with input quantization enabled end up
re-enabled.
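The save-and-restore pattern the reviewer is asking for can be sketched in a few lines. `Quantizer` here is a toy stand-in, not ModelOpt's actual quantizer class; the point is only that a code path skipping `postprocess()` must restore the enabled flag itself:

```python
# Toy stand-in for a quantizer with an on/off flag.
class Quantizer:
    def __init__(self, enabled=True):
        self.enabled = enabled

    def disable(self):
        self.enabled = False

    def enable(self):
        self.enabled = True

q = Quantizer(enabled=True)
was_input_quantized = q.enabled  # remembered at setup() time
q.disable()                      # setup() turns input quantization off for caching

# ...fallback path for uncalibrated experts, which skips postprocess()...

if was_input_quantized:          # the proposed fix: re-enable only if it started on
    q.enable()
assert q.enabled
```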
♻️ Duplicate comments (1)

modelopt/torch/quantization/model_calib.py (1)

1226-1242: ⚠️ Potential issue | 🟠 Major

Restore `input_quantizer` in the uncalibrated-expert fallback.

`AWQLiteHelper.setup()` disables the input quantizer at lines 1009-1015, and the normal re-enable path lives in `postprocess()` at lines 1204-1215. Because this branch skips `postprocess()`, experts that started with input quantization enabled silently stay disabled after AWQ completes.

🐛 Proposed fix

```diff
  module.input_quantizer.pre_quant_scale = torch.ones(
      w_shape, dtype=w_dtype, device=w_device,
  )
+ if module.awq_lite.is_input_quantized:
+     module.input_quantizer.enable()
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_calib.py` around lines 1226-1242: this branch skips postprocess(), so restore the input quantizer exactly as postprocess() does. After the weight-calibration block (inside the uncalibrated-expert fallback), re-enable the module's input quantizer and set its pre-quant scale state by applying the same changes postprocess() applies, e.g. flip the input_quantizer enabled flag back on, set module.input_quantizer._enable_pre_quant_scale = True, and set module.input_quantizer.pre_quant_scale = torch.ones(...) (use w_shape/w_dtype/w_device), mirroring AWQLiteHelper.setup() and postprocess() behavior so experts that started with input quantization enabled are re-enabled here as well.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 1226-1242: This branch skips postprocess(), so restore the input
quantizer exactly as postprocess() does: after the weight-calibration block
(inside the uncalibrated-expert fallback), re-enable the module's input
quantizer and set its pre-quant scale state by applying the same changes
postprocess() applies — e.g. flip the input_quantizer enabled flag back on and
set module.input_quantizer._enable_pre_quant_scale = True and
module.input_quantizer.pre_quant_scale = torch.ones(...) (use
w_shape/w_dtype/w_device), mirroring AWQLiteHelper.setup and postprocess()
behavior so experts that started with input quantization enabled are re-enabled
here as well.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 1b2b5045-67b9-42d6-a698-d0441e711bdb
📒 Files selected for processing (1)
modelopt/torch/quantization/model_calib.py
Summary

- **Problem:** `AttributeError: 'float' object has no attribute 'pow'` when running AWQ lite with `moe_calib_experts_ratio < 1.0` on MoE models (e.g. Qwen3-30B-A3B).
- **Root cause:** With `moe_calib_experts_ratio=0.5`, some MoE experts receive zero tokens during the AWQ cache phase, leaving `act_scale` as a Python float `0.0` instead of a tensor. This causes two failures:
  - `get_scale()` crashes because `float.pow()` doesn't exist.
  - Calibrated experts get a `pre_quant_scale` but uncalibrated ones don't, causing `torch.stack()` to fail on mixed `None`/tensor values in `preprocess_linear_fusion()`.
- **Fix:** Handle uncalibrated experts (those with `num_cache_steps == 0`) in two stages:
  - Disable AWQ search (`is_enabled = False`) to prevent the `get_scale()` crash on a float `act_scale`.
  - Max calibrate the weights and register a neutral (all-ones) `pre_quant_scale` so export can stack scaling factors consistently across all experts. The `pre_quant_scale` buffer must be registered outside `enable_weight_access_and_writeback` because HF accelerate's `post_forward` hook drops newly-registered submodule buffers.

Test plan

- `Qwen/Qwen3-30B-A3B`, `--qformat int4_awq`, `--moe_calib_experts_ratio 0.5`: verify no crash during calibration and export.

🤖 Generated with Claude Code
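The second failure mode in the summary, mixed `None`/tensor values at fusion time, can also be sketched without torch. `stack_scales` below is a hypothetical analogue of the stacking done in `preprocess_linear_fusion()`, not the real function:

```python
# Stacking per-expert scales fails if any expert never got a pre_quant_scale.
def stack_scales(scales):
    if any(s is None for s in scales):
        raise TypeError("expected a sequence of tensors, got None")
    return [list(s) for s in scales]

before_fix = [[1.0, 1.0], None, [0.5, 2.0]]       # uncalibrated expert contributes None
after_fix = [[1.0, 1.0], [1.0, 1.0], [0.5, 2.0]]  # neutral all-ones scale instead

try:
    stack_scales(before_fix)
except TypeError:
    pass  # reproduces the export-time crash
assert stack_scales(after_fix) == after_fix
```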