Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 24 additions & 69 deletions transformer_engine/pytorch/tensor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,46 +118,6 @@ def quantize_master_weights(
else:
use_fsdp_shard_model_weights = True

# Batch convert master_weights to model dtype for NVFP4 (single kernel instead of N kernels)
# Check if there are any NVFP4 weights
has_nvfp4 = any(
isinstance(w._get_quantizer(), NVFP4Quantizer)
for w in model_weights
if hasattr(w, "_get_quantizer")
)
if has_nvfp4 and len(model_weights) > 0:
# Find target dtype from first NVFP4 weight
target_dtype = None
for w in model_weights:
if hasattr(w, "_get_quantizer") and isinstance(w._get_quantizer(), NVFP4Quantizer):
target_dtype = w.dtype
break

if target_dtype is not None:
# Collect non-None master_weights and their indices
non_none_indices = []
non_none_weights = []
sizes = []
for i, mw in enumerate(master_weights):
if mw is not None:
non_none_indices.append(i)
non_none_weights.append(mw.view(-1))
sizes.append(mw.numel())

if len(non_none_weights) > 0 and non_none_weights[0].dtype != target_dtype:
# Concatenate, convert once, then split
concatenated = torch.cat(non_none_weights)
converted = concatenated.to(target_dtype)
split_weights = torch.split(converted, sizes)

# Rebuild master_weights list with converted tensors
converted_master_weights = list(master_weights)
for idx, split_w, orig_mw in zip(
non_none_indices, split_weights, [master_weights[i] for i in non_none_indices]
):
converted_master_weights[idx] = split_w.view(orig_mw.shape)
master_weights = converted_master_weights

for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip(
model_weights, master_weights, start_offsets, fsdp_shard_model_weights
):
Expand All @@ -176,42 +136,37 @@ def quantize_master_weights(
if hasattr(model_weight, "clear_high_precision_init_val"):
model_weight.clear_high_precision_init_val()

if master_weight is not None:
# When not using fp8/fp4_primary_weights, the master_weight (fp32) is first cast to
# bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when
# fp8/fp4_primary_weights is enabled, we still keep this logic to keep numerical
# consistency. So here we cast the master_weight to model_weight.dtype.
master_weight = master_weight.to(model_weight.dtype)

quantizer = model_weight._get_quantizer()

if isinstance(quantizer, NVFP4Quantizer):
# NVFP4: master_weight dtype conversion already done above
nvfp4_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8Quantizer):
delayed_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8CurrentScalingQuantizer):
current_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8BlockQuantizer):
blockwise_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer):
mxfp8_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
else:
# FP8: convert master_weight to model dtype
if master_weight is not None:
# When not using fp8_primary_weights, the master_weight (fp32) is first cast to
# bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when
# fp8_primary_weights is enabled, we still keep this logic to keep numerical
# consistency. So here we cast the master_weight to model_weight.dtype.
master_weight = master_weight.to(model_weight.dtype)

if isinstance(quantizer, Float8Quantizer):
delayed_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8CurrentScalingQuantizer):
current_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, Float8BlockQuantizer):
blockwise_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
elif isinstance(quantizer, MXFP8Quantizer):
mxfp8_scaling_params.append(
(model_weight, master_weight, start_offset, fsdp_shard_model_weight)
)
else:
raise ValueError(
f"quantize_master_weights for {type(quantizer)} is not supported yet"
)
raise ValueError(f"quantize_master_weights for {type(quantizer)} is not supported yet")

extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing]
if len(delayed_scaling_params) > 0:
Expand Down
Loading