diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index c80bc8aaa4..ba44c7a619 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -118,46 +118,6 @@ def quantize_master_weights( else: use_fsdp_shard_model_weights = True - # Batch convert master_weights to model dtype for NVFP4 (single kernel instead of N kernels) - # Check if there are any NVFP4 weights - has_nvfp4 = any( - isinstance(w._get_quantizer(), NVFP4Quantizer) - for w in model_weights - if hasattr(w, "_get_quantizer") - ) - if has_nvfp4 and len(model_weights) > 0: - # Find target dtype from first NVFP4 weight - target_dtype = None - for w in model_weights: - if hasattr(w, "_get_quantizer") and isinstance(w._get_quantizer(), NVFP4Quantizer): - target_dtype = w.dtype - break - - if target_dtype is not None: - # Collect non-None master_weights and their indices - non_none_indices = [] - non_none_weights = [] - sizes = [] - for i, mw in enumerate(master_weights): - if mw is not None: - non_none_indices.append(i) - non_none_weights.append(mw.view(-1)) - sizes.append(mw.numel()) - - if len(non_none_weights) > 0 and non_none_weights[0].dtype != target_dtype: - # Concatenate, convert once, then split - concatenated = torch.cat(non_none_weights) - converted = concatenated.to(target_dtype) - split_weights = torch.split(converted, sizes) - - # Rebuild master_weights list with converted tensors - converted_master_weights = list(master_weights) - for idx, split_w, orig_mw in zip( - non_none_indices, split_weights, [master_weights[i] for i in non_none_indices] - ): - converted_master_weights[idx] = split_w.view(orig_mw.shape) - master_weights = converted_master_weights - for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( model_weights, master_weights, start_offsets, fsdp_shard_model_weights ): @@ -176,42 +136,37 @@ def quantize_master_weights( if hasattr(model_weight, "clear_high_precision_init_val"): model_weight.clear_high_precision_init_val() + if master_weight is not None: + # When not using fp8/fp4_primary_weights, the master_weight (fp32) is first cast to + # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when + # fp8/fp4_primary_weights is enabled, we still keep this logic to keep numerical + # consistency. So here we cast the master_weight to model_weight.dtype. + master_weight = master_weight.to(model_weight.dtype) + quantizer = model_weight._get_quantizer() if isinstance(quantizer, NVFP4Quantizer): - # NVFP4: master_weight dtype conversion already done above nvfp4_params.append( (model_weight, master_weight, start_offset, fsdp_shard_model_weight) ) + elif isinstance(quantizer, Float8Quantizer): + delayed_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + current_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8BlockQuantizer): + blockwise_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, MXFP8Quantizer): + mxfp8_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) else: - # FP8: convert master_weight to model dtype - if master_weight is not None: - # When not using fp8_primary_weights, the master_weight (fp32) is first cast to - # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when - # fp8_primary_weights is enabled, we still keep this logic to keep numerical - # consistency. So here we cast the master_weight to model_weight.dtype. - master_weight = master_weight.to(model_weight.dtype) - - if isinstance(quantizer, Float8Quantizer): - delayed_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8CurrentScalingQuantizer): - current_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8BlockQuantizer): - blockwise_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, MXFP8Quantizer): - mxfp8_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - else: - raise ValueError( - f"quantize_master_weights for {type(quantizer)} is not supported yet" - ) + raise ValueError(f"quantize_master_weights for {type(quantizer)} is not supported yet") extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing] if len(delayed_scaling_params) > 0: