
@Baidu-AIAK

Problem Description

In the current Megatron codebase, `--fp8-param-gather` cannot be combined with optimizer CPU offload when the blockwise FP8 recipe is used. The existing implementation forcibly prohibits enabling both options at the same time via an assert:

if args.optimizer_cpu_offload:
    assert args.use_precision_aware_optimizer, (
        "The optimizer cpu offload must be used in conjunction with `--use-precision-aware-optimizer`, "
        "as the hybrid device optimizer reuses the code path of this flag."
    )
    assert not args.fp8_param_gather or args.fp8_recipe == "delayed", (
        "When `--fp8-param-gather` is enabled, the optimizer cpu offload "
        "must be used in conjunction with `--fp8-recipe delayed`."
    )

This PR provides a solution that allows both options to be enabled simultaneously.
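
Conceptually, the second assert is relaxed so that non-delayed recipes (such as blockwise) are no longer rejected when optimizer CPU offload and `--fp8-param-gather` are both enabled. A minimal sketch of such a relaxation (illustrative only; the exact diff in this PR may differ):

if args.optimizer_cpu_offload:
    assert args.use_precision_aware_optimizer, (
        "The optimizer cpu offload must be used in conjunction with `--use-precision-aware-optimizer`, "
        "as the hybrid device optimizer reuses the code path of this flag."
    )
    # Illustrative sketch: the fp8-recipe restriction on `--fp8-param-gather` is dropped,
    # since the hybrid device optimizer changes described below handle non-delayed recipes.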

Our Solution

self.use_precision_aware_optimizer_no_fp8_or_ds_fp8 = (
    self.use_precision_aware_optimizer
    and (
        self.main_params_dtype != torch.float32
        or (self.fp8_recipe is None or self.fp8_recipe == "delayed")
        or (self.optimizer_cpu_offload and not self.fp8_param_gather)
    )
)

We first extended the `use_precision_aware_optimizer_no_fp8_or_ds_fp8` flag so that it is also set to `True` when `optimizer_cpu_offload` is enabled and `fp8_param_gather` is disabled. Because the added condition only fires when `fp8_param_gather` is disabled, it does not change the flag's value in the previously disallowed case where `fp8_param_gather` is enabled.
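
As a standalone illustration (not the Megatron source; the helper name and the standalone form are ours), the extended condition evaluates as follows for the CPU-offload configurations discussed here:

# Standalone illustration of the extended flag logic (helper name is hypothetical).
import torch

def no_fp8_or_ds_fp8(use_precision_aware_optimizer, main_params_dtype,
                     fp8_recipe, optimizer_cpu_offload, fp8_param_gather):
    return use_precision_aware_optimizer and (
        main_params_dtype != torch.float32
        or (fp8_recipe is None or fp8_recipe == "delayed")
        or (optimizer_cpu_offload and not fp8_param_gather)  # condition added by this PR
    )

# Blockwise recipe + CPU offload, fp8_param_gather disabled -> True.
assert no_fp8_or_ds_fp8(True, torch.float32, "blockwise", True, False)
# Blockwise recipe + CPU offload, fp8_param_gather enabled -> False,
# so the explicit copy/linkage handling described below is used instead.
assert not no_fp8_or_ds_fp8(True, torch.float32, "blockwise", True, True)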

All the code we added is guarded by the following check:
if isinstance(self.optimizer, HybridDeviceOptimizer):
Therefore, the conditions and logic we introduced apply only when optimizer offload is enabled. Under optimizer offload, the only case in which `use_precision_aware_optimizer_no_fp8_or_ds_fp8` can be `False` is when FP8 is enabled together with `fp8_param_gather`.

if isinstance(self.optimizer, HybridDeviceOptimizer):
    # Copy model groups to shard groups.
    # HDO uses `shard_params` as its param_groups:
    # - bf16 case: `shard_params` shares the same underlying storage as the model
    #   params, so it is already updated after loading the checkpoint.
    # - fp8 case: `shard_params` is an fp32 copy of the model params and is not
    #   updated after loading the checkpoint, so an explicit
    #   "model param -> shard param" copy is needed.
    if not self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
        copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
        copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)
    self.optimizer.update_fp32_param_by_new_param()
    return

In the `_copy_model_params_to_main_params` function, we need to explicitly perform the parameter copy for HDO when `--fp8-param-gather` is enabled. In the other cases the code works through `model_shard`, which is a reference to the model parameters, whereas with `--fp8-param-gather` a clone is used, so a manual copy is required.
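
For context, a group-wise copy such as `copy_group_params` conceptually does the following (a simplified sketch that ignores the per-bucket shard-range bookkeeping of the real Megatron helper):

# Simplified sketch of a group-wise "model param -> main param" copy
# (the real Megatron helper also handles per-bucket shard ranges).
def copy_group_params_sketch(model_groups, shard_main_groups):
    for model_group, shard_main_group in zip(model_groups, shard_main_groups):
        for model_param, shard_main_param in zip(model_group, shard_main_group):
            # copy_ casts to the destination dtype, e.g. fp8/bf16 -> fp32 main param
            shard_main_param.data.copy_(model_param.data)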

if self.param_update_in_fp32:
    # If the passed-in param is already the fp32 shard main param, keep the
    # self-reference; otherwise (e.g., an fp8 model param when `--fp8-param-gather`
    # is enabled) make a detached fp32 clone for the optimizer.
    if param.dtype != torch.float32:
        param = param.detach().clone().float()

For the same reason, we also modified the `_get_sub_optimizer_param_groups` function of `HybridDeviceOptimizer` to establish the linkage between the optimizer parameters and the cloned parameters.
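
A hypothetical sketch of what such a linkage can look like (the names below are illustrative, not the actual HybridDeviceOptimizer attributes):

# Hypothetical sketch of linking an fp32 clone back to the model param it came from.
import torch

def make_main_param(param, fp32_to_model_param):
    # Under --fp8-param-gather the incoming param is not fp32, so the optimizer
    # works on a detached fp32 clone; remember which model param it came from so
    # the two can be kept in sync later.
    if param.dtype != torch.float32:
        main_param = param.detach().clone().float()
        fp32_to_model_param[main_param] = param
        return main_param
    # Already fp32: the optimizer can reference the param directly.
    return param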

if isinstance(self.optimizer, HybridDeviceOptimizer):
    for k, v in tensors.items():
        if k == "param":
            # HDO keeps the fp32 main weight in its per-param optimizer state
            # under "master_param" rather than as the param itself.
            k = "master_param"
        optim_state[k] = v

At the same time, we modified the optimizer-state loading process in the `_set_main_param_and_optimizer_states` function. This ensures that after loading a checkpoint, the concrete parameters inside HDO are properly updated, thereby maintaining numerical correctness.
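
As a small illustration with dummy tensors (not real checkpoint contents), the rename places the loaded fp32 main weight under the key used for HDO's master copy:

# Illustration with dummy tensors (not real checkpoint contents).
import torch

tensors = {
    "param": torch.zeros(4),       # fp32 main weight restored from the checkpoint
    "exp_avg": torch.zeros(4),     # Adam first moment
    "exp_avg_sq": torch.zeros(4),  # Adam second moment
}
optim_state = {}
for k, v in tensors.items():
    if k == "param":
        k = "master_param"  # key under which HDO keeps its fp32 main weight
    optim_state[k] = v
assert set(optim_state) == {"master_param", "exp_avg", "exp_avg_sq"}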

Evaluation

We conducted experiments comparing accuracy under different configurations. The results consistently validate the correctness of our changes and show that they do not affect the correctness of unmodified code paths. The experiments were conducted on a single-node, 8-GPU setup using the dsv2lite model.

1.1 random param initialization

fp8-no-load-ckpt parameter settings:
--fp8-format e4m3
--fp8-recipe blockwise
--fp8-param-gather
--use-precision-aware-optimizer
--optimizer-cpu-offload
--optimizer-offload-fraction 1.0

When no checkpoint is loaded, there is no accuracy discrepancy, both before and after the changes.

1.2 load checkpoint without optimizer state

fp8-load-ckpt-no-load-opt parameter settings:
--fp8-format e4m3
--fp8-recipe blockwise
--fp8-param-gather
--use-precision-aware-optimizer
--optimizer-cpu-offload
--optimizer-offload-fraction 1.0

--load $CHECKPOINT_PATH
--no-load-optim
--no-load-rng

From the accuracy curves, we can see that before the fix, enabling --fp8-param-gather with optimizer offload led to a very large discrepancy compared to the non-offload case, whereas after the fix the discrepancy is reduced to the order of 1e-4.
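
One hypothetical way to quantify the gap between two such curves (illustrative only, not necessarily the exact metric behind the numbers above):

# Hypothetical post-processing sketch: compare per-step losses of two runs,
# e.g. the non-offload baseline vs. offload with --fp8-param-gather.
def max_abs_loss_diff(losses_ref, losses_test):
    return max(abs(a - b) for a, b in zip(losses_ref, losses_test))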

1.3 load checkpoint with optimizer state

fp8-load-ckpt-load-opt parameter settings:
--fp8-format e4m3
--fp8-recipe blockwise
--fp8-param-gather
--use-precision-aware-optimizer
--optimizer-cpu-offload
--optimizer-offload-fraction 1.0

--load $CHECKPOINT_PATH
#--no-load-optim
#--no-load-rng

From the accuracy curves, it is evident that before the fix, enabling --fp8-param-gather with optimizer offload resulted in a very large discrepancy compared to the non-offload case, whereas after the fix the discrepancy is reduced to the order of 1e-4.

1.4 no --fp8-param-gather after change

no-fp8-param-gather parameter settings:
--fp8-format e4m3
--fp8-recipe blockwise
#--fp8-param-gather
--use-precision-aware-optimizer
--optimizer-cpu-offload
--optimizer-offload-fraction 1.0

--load $CHECKPOINT_PATH
#--no-load-optim
#--no-load-rng

The experiments show that the changes do not affect correctness when fp8-param-gather is disabled.

Summary

Overall, this change enables optimizer CPU offload to be used together with `--fp8-param-gather` under non-delayed FP8 recipes (such as blockwise), while ensuring numerical correctness and preserving all existing logic.
