[Dev] Support optimizer offload when enable --fp8-param-gather #2788
Problem Description
In the current Megatron codebase, enabling --fp8-param-gather is not supported when FP8 blockwise mode is used. The existing implementation forcibly prohibits enabling both options at the same time via an assert.
This PR provides a solution that allows both options to be enabled simultaneously.
Our Solution
We first modified the use_precision_aware_optimizer_no_fp8_or_ds_fp8 flag so that it is set to true when optimizer_cpu_offload is enabled and fp8_param_gather is disabled. This change therefore does not affect the original logic, in which enabling fp8_param_gather was not allowed.
All the code we added includes the following check:
if isinstance(self.optimizer, HybridDeviceOptimizer)
Therefore, the conditions and logic we introduced apply only when optimizer offload is enabled. Under optimizer offload, the only case in which use_precision_aware_optimizer_no_fp8_or_ds_fp8 can be False is when FP8 is enabled together with fp8_param_gather.
In the _copy_model_params_to_main_params function, we need to explicitly perform a parameter copy for HybridDeviceOptimizer (HDO) when fp8-param-gather is enabled. In other cases the code uses model_shard, which is a reference, whereas a clone is used when fp8-param-gather is enabled, so a manual copy is required.
For the same reason, we also modified the _get_sub_optimizer_param_groups function of HybridDeviceOptimizer to establish the linkage between the optimizer parameters and the cloned parameters.
We also modified the optimizer state loading process in the _set_main_param_and_optimizer_states function. This ensures that after a checkpoint is loaded, the concrete parameters in HDO are properly updated, thereby maintaining numerical correctness.
Evaluation
We conducted experiments comparing accuracy under different configurations. The results consistently validate the correctness of our changes and show that they do not affect the correctness of unmodified code paths. The experiments were conducted on a single-node, 8-GPU setup using the dsv2lite model.
1.1 random param initialization
When no checkpoint is loaded, there is no accuracy discrepancy, both before and after the changes.
1.2 load checkpoint without optimizer state
--fp8-format e4m3 --fp8-recipe blockwise --fp8-param-gather --use-precision-aware-optimizer --optimizer-cpu-offload --optimizer-offload-fraction 1.0 --load $CHECKPOINT_PATH --no-load-optim --no-load-rng
From the accuracy curves, we can see that before the fix, enabling --fp8-param-gather with optimizer offload led to a very large discrepancy compared to the non-offload case, whereas after the fix the discrepancy is reduced to the order of 1e-4.
1.3 load checkpoint with optimizer state
From the accuracy curves, it is evident that before the fix, enabling --fp8-param-gather with optimizer offload resulted in a very large discrepancy compared to the non-offload case, whereas after the fix the discrepancy is reduced to the order of 1e-4.
1.4 no --fp8-param-gather after change
The experiments show that the changes do not affect correctness when fp8-param-gather is disabled.
Summary
Overall, we enabled optimizer CPU offload when FP8 is not delayed and --fp8-param-gather is enabled, while ensuring numerical correctness and preserving all existing logic.
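For readers who want a concrete picture of the reference-versus-clone point made under Our Solution, the following is a minimal toy sketch, not the code from this PR. It assumes plain Python lists can stand in for tensors, and every name in it (ToyOffloadOptimizer, sync_from, load_checkpoint_into) is hypothetical and does not exist in Megatron. It only shows why a main parameter that is a reference picks up in-place weight updates for free, while a clone (and any CPU-side copy derived from it) needs an explicit copy.

```python
# Toy illustration only: plain lists stand in for tensors, and all names
# below are hypothetical, not Megatron or HybridDeviceOptimizer APIs.


class ToyOffloadOptimizer:
    """Keeps an independent CPU-side copy of each parameter it is given."""

    def __init__(self, params):
        # Key each CPU copy by the identity of the source parameter object.
        self.params = params
        self.cpu_copies = {id(p): list(p) for p in params}

    def sync_from(self, param):
        # Manual copy: refresh the CPU copy with the current values of `param`.
        self.cpu_copies[id(param)][:] = param


def load_checkpoint_into(model_shard, new_values):
    # In-place write, similar in spirit to loading weights into a model shard.
    model_shard[:] = new_values


# Case 1: the main parameter is a reference to the model shard.
model_shard = [0.0, 0.0]
main_param = model_shard  # same object, no copy
load_checkpoint_into(model_shard, [1.0, 2.0])
reference_sees_update = main_param == [1.0, 2.0]

# Case 2: the main parameter is a clone of the model shard.
model_shard = [0.0, 0.0]
main_clone = list(model_shard)  # independent storage
opt = ToyOffloadOptimizer([main_clone])
load_checkpoint_into(model_shard, [1.0, 2.0])
clone_is_stale = main_clone == [0.0, 0.0]

# The explicit copy: model shard -> clone -> optimizer's CPU copy.
main_clone[:] = model_shard
opt.sync_from(main_clone)
cpu_copy_after_sync = opt.cpu_copies[id(main_clone)]
```

In this toy, skipping either of the last two copy steps leaves the optimizer working on stale values, which is the kind of mismatch the explicit copy and the parameter linkage described above are meant to prevent.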