Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,17 @@ def sync_act_scale_across_dp(module, data_parallel_group):
module.parallel_state.data_parallel_group,
)

# Disable AWQ search for uncalibrated experts (num_cache_steps == 0) to
# prevent get_scale() crash on float act_scale. Max calibration and neutral
# pre_quant_scale are applied in the postprocessing loop below.
for name, module in model.named_modules():
if (
is_quantized_linear(module)
and hasattr(module, "awq_lite")
and module.awq_lite.num_cache_steps == 0
):
module.awq_lite.is_enabled = False

AWQLiteHelper.cache_mode = False
print_rank_0("awq_lite: Searching parameters...")
with torch.no_grad():
Expand Down Expand Up @@ -1212,16 +1223,33 @@ def postprocess(module, name):
for name, module in model.named_modules():
if hasattr(module, "awq_lite"):
if module.awq_lite.num_cache_steps == 0:
module.awq_lite.is_enabled = False
elif module.awq_lite.num_search_steps == 0:
module.awq_lite.is_enabled = False
warnings.warn(
"awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
f" {name}. Please provide a valid `forward_loop` function that can be used to"
" forward data through the model many times."
# Uncalibrated expert: max calibrate weights and apply neutral
# (all-ones) pre_quant_scale for export consistency.
# NOTE: ones_scale must be registered OUTSIDE enable_weight_access_and_writeback
# because HF accelerate post_forward drops newly-registered submodule buffers.
with enable_weight_access_and_writeback(module, model, name_to_module):
max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
w_shape, w_dtype, w_device = (
module.weight.shape[1],
module.weight.dtype,
module.weight.device,
)
module.input_quantizer._enable_pre_quant_scale = True
module.input_quantizer.pre_quant_scale = torch.ones(
w_shape,
dtype=w_dtype,
device=w_device,
)
with enable_weight_access_and_writeback(module, model, name_to_module):
postprocess(module, name)
else:
if module.awq_lite.num_search_steps == 0:
module.awq_lite.is_enabled = False
warnings.warn(
"awq_lite: Calling `forward_loop(model)` the second time did not forward"
f" data through the {name}. Please provide a valid `forward_loop` function"
" that can be used to forward data through the model many times."
)
with enable_weight_access_and_writeback(module, model, name_to_module):
postprocess(module, name)

module.awq_lite.cleanup()
if not debug:
Expand Down
Loading