From 0e244cf5d7994881ae550e7b03d9f9f3e7f6383b Mon Sep 17 00:00:00 2001 From: qiyuw Date: Fri, 3 Apr 2026 13:58:31 -0700 Subject: [PATCH 1/3] fix memory overheads Signed-off-by: qiyuw --- transformer_engine/pytorch/tensor/utils.py | 107 +++++---------------- 1 file changed, 26 insertions(+), 81 deletions(-) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index c80bc8aaa4..4f001588e2 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -118,100 +118,45 @@ def quantize_master_weights( else: use_fsdp_shard_model_weights = True - # Batch convert master_weights to model dtype for NVFP4 (single kernel instead of N kernels) - # Check if there are any NVFP4 weights - has_nvfp4 = any( - isinstance(w._get_quantizer(), NVFP4Quantizer) - for w in model_weights - if hasattr(w, "_get_quantizer") - ) - if has_nvfp4 and len(model_weights) > 0: - # Find target dtype from first NVFP4 weight - target_dtype = None - for w in model_weights: - if hasattr(w, "_get_quantizer") and isinstance(w._get_quantizer(), NVFP4Quantizer): - target_dtype = w.dtype - break - - if target_dtype is not None: - # Collect non-None master_weights and their indices - non_none_indices = [] - non_none_weights = [] - sizes = [] - for i, mw in enumerate(master_weights): - if mw is not None: - non_none_indices.append(i) - non_none_weights.append(mw.view(-1)) - sizes.append(mw.numel()) - - if len(non_none_weights) > 0 and non_none_weights[0].dtype != target_dtype: - # Concatenate, convert once, then split - concatenated = torch.cat(non_none_weights) - converted = concatenated.to(target_dtype) - split_weights = torch.split(converted, sizes) - - # Rebuild master_weights list with converted tensors - converted_master_weights = list(master_weights) - for idx, split_w, orig_mw in zip( - non_none_indices, split_weights, [master_weights[i] for i in non_none_indices] - ): - converted_master_weights[idx] = split_w.view(orig_mw.shape) - master_weights = converted_master_weights - for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( model_weights, master_weights, start_offsets, fsdp_shard_model_weights ): - # Clear `_high_precision_init_val` of model_weight automatically. - # - Master weights are initialized from model weights, if we use fp8 primary weights to - # initialize master weights, the numerical values of master weights are not consistent - # with the numerical values when we initialize them from bf16/fp16 weights. - # - So we add a `_high_precision_init_val` attribute to each model weight to store the - # original bf16/fp16 weight on cpu before casting it to fp8. And users can use - # `get_high_precision_init_val` to get this cpu tensor. - # - This cpu tensor is not needed once the master weight is initialized, so users should - # call `clear_high_precision_init_val` to remove it after master weight is initialized. - # - In case users don't call `clear_high_precision_init_val`, we will clear it automatically - # here. It's safe to clear the `_high_precision_init_val` at this time because this - # function is supposed to be called after the master weights are initialized and updated. if hasattr(model_weight, "clear_high_precision_init_val"): model_weight.clear_high_precision_init_val() + if master_weight is not None: + # When not using fp8_primary_weights, the master_weight (fp32) is first cast to + # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when + # fp8_primary_weights is enabled, we still keep this logic to keep numerical + # consistency. So here we cast the master_weight to model_weight.dtype. + master_weight = master_weight.to(model_weight.dtype) + quantizer = model_weight._get_quantizer() if isinstance(quantizer, NVFP4Quantizer): - # NVFP4: master_weight dtype conversion already done above nvfp4_params.append( (model_weight, master_weight, start_offset, fsdp_shard_model_weight) ) + elif isinstance(quantizer, Float8Quantizer): + delayed_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8CurrentScalingQuantizer): + current_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, Float8BlockQuantizer): + blockwise_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) + elif isinstance(quantizer, MXFP8Quantizer): + mxfp8_scaling_params.append( + (model_weight, master_weight, start_offset, fsdp_shard_model_weight) + ) else: - # FP8: convert master_weight to model dtype - if master_weight is not None: - # When not using fp8_primary_weights, the master_weight (fp32) is first cast to - # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when - # fp8_primary_weights is enabled, we still keep this logic to keep numerical - # consistency. So here we cast the master_weight to model_weight.dtype. - master_weight = master_weight.to(model_weight.dtype) - - if isinstance(quantizer, Float8Quantizer): - delayed_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8CurrentScalingQuantizer): - current_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, Float8BlockQuantizer): - blockwise_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - elif isinstance(quantizer, MXFP8Quantizer): - mxfp8_scaling_params.append( - (model_weight, master_weight, start_offset, fsdp_shard_model_weight) - ) - else: - raise ValueError( - f"quantize_master_weights for {type(quantizer)} is not supported yet" - ) + raise ValueError( + f"quantize_master_weights for {type(quantizer)} is not supported yet" + ) extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing] if len(delayed_scaling_params) > 0: From 145e164a592a2f7a4ff57769147a511f80031c1d Mon Sep 17 00:00:00 2001 From: qiyuw Date: Fri, 3 Apr 2026 14:06:50 -0700 Subject: [PATCH 2/3] comments Signed-off-by: qiyuw --- transformer_engine/pytorch/tensor/utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 4f001588e2..1b59d7e22a 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -121,13 +121,25 @@ def quantize_master_weights( for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip( model_weights, master_weights, start_offsets, fsdp_shard_model_weights ): + # Clear `_high_precision_init_val` of model_weight automatically. + # - Master weights are initialized from model weights, if we use fp8 primary weights to + # initialize master weights, the numerical values of master weights are not consistent + # with the numerical values when we initialize them from bf16/fp16 weights. + # - So we add a `_high_precision_init_val` attribute to each model weight to store the + # original bf16/fp16 weight on cpu before casting it to fp8. And users can use + # `get_high_precision_init_val` to get this cpu tensor. + # - This cpu tensor is not needed once the master weight is initialized, so users should + # call `clear_high_precision_init_val` to remove it after master weight is initialized. + # - In case users don't call `clear_high_precision_init_val`, we will clear it automatically + # here. It's safe to clear the `_high_precision_init_val` at this time because this + # function is supposed to be called after the master weights are initialized and updated. if hasattr(model_weight, "clear_high_precision_init_val"): model_weight.clear_high_precision_init_val() if master_weight is not None: - # When not using fp8_primary_weights, the master_weight (fp32) is first cast to + # When not using fp8/fp4_primary_weights, the master_weight (fp32) is first cast to # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when - # fp8_primary_weights is enabled, we still keep this logic to keep numerical + # fp8/fp4_primary_weights is enabled, we still keep this logic to keep numerical # consistency. So here we cast the master_weight to model_weight.dtype. master_weight = master_weight.to(model_weight.dtype) From a2074668b061dda4f52fd71adea01701afdcd51f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 21:14:17 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 1b59d7e22a..ba44c7a619 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -166,9 +166,7 @@ def quantize_master_weights( (model_weight, master_weight, start_offset, fsdp_shard_model_weight) ) else: - raise ValueError( - f"quantize_master_weights for {type(quantizer)} is not supported yet" - ) + raise ValueError(f"quantize_master_weights for {type(quantizer)} is not supported yet") extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing] if len(delayed_scaling_params) > 0: