diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index ed57ea3fc7..89097fd32c 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -1179,6 +1179,17 @@ def sync_act_scale_across_dp(module, data_parallel_group):
                 module.parallel_state.data_parallel_group,
             )
 
+    # Disable AWQ search for uncalibrated experts (num_cache_steps == 0) to
+    # prevent get_scale() crash on float act_scale. Max calibration and neutral
+    # pre_quant_scale are applied in the postprocessing loop below.
+    for name, module in model.named_modules():
+        if (
+            is_quantized_linear(module)
+            and hasattr(module, "awq_lite")
+            and module.awq_lite.num_cache_steps == 0
+        ):
+            module.awq_lite.is_enabled = False
+
     AWQLiteHelper.cache_mode = False
     print_rank_0("awq_lite: Searching parameters...")
     with torch.no_grad():
@@ -1212,16 +1223,33 @@ def postprocess(module, name):
     for name, module in model.named_modules():
         if hasattr(module, "awq_lite"):
             if module.awq_lite.num_cache_steps == 0:
-                module.awq_lite.is_enabled = False
-            elif module.awq_lite.num_search_steps == 0:
-                module.awq_lite.is_enabled = False
-                warnings.warn(
-                    "awq_lite: Calling `forward_loop(model)` the second time did not forward data through the"
-                    f" {name}. Please provide a valid `forward_loop` function that can be used to"
-                    " forward data through the model many times."
+                # Uncalibrated expert: max calibrate weights and apply neutral
+                # (all-ones) pre_quant_scale for export consistency.
+                # NOTE: ones_scale must be registered OUTSIDE enable_weight_access_and_writeback
+                # because HF accelerate post_forward drops newly-registered submodule buffers.
+                with enable_weight_access_and_writeback(module, model, name_to_module):
+                    max_calibrate(module, lambda module: module.weight_quantizer(module.weight))
+                w_shape, w_dtype, w_device = (
+                    module.weight.shape[1],
+                    module.weight.dtype,
+                    module.weight.device,
+                )
+                module.input_quantizer._enable_pre_quant_scale = True
+                module.input_quantizer.pre_quant_scale = torch.ones(
+                    w_shape,
+                    dtype=w_dtype,
+                    device=w_device,
                 )
-            with enable_weight_access_and_writeback(module, model, name_to_module):
-                postprocess(module, name)
+            else:
+                if module.awq_lite.num_search_steps == 0:
+                    module.awq_lite.is_enabled = False
+                    warnings.warn(
+                        "awq_lite: Calling `forward_loop(model)` the second time did not forward"
+                        f" data through the {name}. Please provide a valid `forward_loop` function"
+                        " that can be used to forward data through the model many times."
+                    )
+                with enable_weight_access_and_writeback(module, model, name_to_module):
+                    postprocess(module, name)
             module.awq_lite.cleanup()
 
     if not debug: