
[FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer states should also be DTensors.#2795

Open
cspades wants to merge 12 commits into NVIDIA:main from cspades:cye/fused-adam-dcp

Conversation


@cspades cspades commented Mar 24, 2026

Description

Without this fix, Megatron-FSDP checkpointing with DCP in MLM breaks:

[rank0]:   File "/opt/megatron-lm/megatron/training/checkpointing.py", line 1015, in preprocess_fsdp_dtensor_state_dict
[rank0]:     model_state_dict, optimizer_state_dict = handle_swiglu_in_state_dict(
[rank0]:                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/megatron-lm/megatron/core/transformer/fsdp_dtensor_checkpoint.py", line 318, in handle_swiglu_in_state_dict
[rank0]:     weight_w, weight_v = split_swiglu_linear_fc1(
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/megatron-lm/megatron/core/transformer/fsdp_dtensor_checkpoint.py", line 246, in split_swiglu_linear_fc1
[rank0]:     local_tensor = data.to_local()
[rank0]:                    ^^^^^^^^^^^^^
[rank0]: AttributeError: 'Tensor' object has no attribute 'to_local'

and it causes FSDP2 to lack a distributed (DTensor) optimizer state. (See the re-sharding tests.)

Details

  • We wrap each optimizer state as a DTensor matching the distribution characteristics of the original DTensor parameter the state is associated with, so the state is always a DTensor while compute and optimizer steps are applied to the local Tensor. (There are very few line changes if you ignore variable renaming.)
  • Test Structure
    • dcp_resharding_save must be run together with, and before, dcp_resharding_load.
    • dcp_resharding_save deletes any existing checkpoint directories and writes a new DCP checkpoint, while dcp_resharding_load deletes the saved DCP checkpoint when it finishes.
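The wrap-at-initialization pattern described in the first bullet can be sketched as follows. This is a simplified illustration, not TE's actual fused_adam.py code: FakeDTensor is a hypothetical stand-in for torch.distributed.tensor.DTensor (used so the example runs without a process group), and initialize_state is an illustrative name, not the real method.

```python
class FakeDTensor:
    """Hypothetical stand-in for torch.distributed.tensor.DTensor:
    a local shard plus the mesh/placement metadata that DCP relies on."""

    def __init__(self, local_tensor, device_mesh, placements):
        self._local_tensor = local_tensor
        self.device_mesh = device_mesh
        self.placements = placements

    def to_local(self):
        return self._local_tensor

    @classmethod
    def from_local(cls, local_tensor, device_mesh, placements):
        return cls(local_tensor, device_mesh, placements)


def initialize_state(param):
    """Create an exp_avg-style state buffer mirroring the param's distribution."""
    if isinstance(param, FakeDTensor):
        # Allocate the state from the *local* shard (stand-in for empty_like) ...
        local_state = [0.0] * len(param.to_local())
        # ... then wrap it so the state shards exactly like the parameter.
        return FakeDTensor.from_local(local_state, param.device_mesh, param.placements)
    # Non-distributed parameters keep plain (list-backed, for this sketch) states.
    return [0.0] * len(param)


# A parameter sharded over a hypothetical 4-rank mesh; this rank holds 2 elements.
param = FakeDTensor([1.0, 2.0], device_mesh="mesh(4)", placements=("Shard(0)",))
state = initialize_state(param)
assert isinstance(state, FakeDTensor)        # the state stays a DTensor
assert state.placements == param.placements  # identical sharding metadata
assert state.to_local() == [0.0, 0.0]        # compute happens on the local shard
```

Because the state carries the same mesh and placements as its parameter, DCP can record and reshard it exactly like the parameter itself.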

Testing

  • TE CI/CD
TE_PATH=/workspace/TransformerEngine ./qa/L1_pytorch_distributed_unittest/test.sh
pytest -v -s tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_dcp_resharding
...
tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_dcp_resharding[DelayedScaling] PASSED
tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_dcp_resharding[Float8CurrentScaling] PASSED
tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_dcp_resharding[Float8BlockScaling] XFAIL (Float8BlockScaling + FSDP2 with 4-rank sharding fails on Blackwell (SM10+): swizzle_block_scaling_to_mxfp8_scaling...)
tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_dcp_resharding[MXFP8BlockScaling] XFAIL (MXFP8BlockScaling: FusedAdam CUDA kernel does not support MXFP8 quantized tensors, causing illegal memory access. F...)
tests/pytorch/distributed/test_torch_fsdp2.py::test_fsdp2_fused_adam_dcp_resharding[NVFP4BlockScaling] XFAIL (NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() which calls data_ptr() on NVFP4Tensor wrapper...)
2 passed, 3 xfailed, 2 warnings in 145.61s (0:02:25)
  • Megatron-LM + --use-precision-aware-optimizer
# TE@00ba0b493c27f32e2f210b0022132c50da78dac7 (Llama 8B + Precision-Aware Optimizer + FP8Blockwise + TP2 + GB300)
[2026-03-25 15:18:07.588704] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 9614.0 | throughput per GPU (TFLOP/s/GPU): 1403.6 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.131176E+00 | loss scale: 1.0 | grad norm: 5.337 | number of skipped iterations:   0 | number of nan iterations:   0 |

# This PR (Llama 8B Precision-Aware Optimizer + FP8Blockwise + TP2 + GB300)
[2026-03-25 14:58:55.856189] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 9588.0 | throughput per GPU (TFLOP/s/GPU): 1407.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.131045E+00 | loss scale: 1.0 | grad norm: 5.336 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • This PR fixes the Megatron-FSDP uneven-DTensor preprocessing error, and training works fine:
# Megatron-FSDP + self.use_precision_aware_optimizer=True + --use-precision-aware-optimizer + BF16 + HFSDP
# And: --save-interval 1 and --ckpt-format fsdp_dtensor to reproduce the checkpointing error quickly.
[2026-03-31 08:37:02.253124] iteration        3/15258789 | consumed samples:          384 | elapsed time per iteration (ms): 18363.2 | throughput per GPU (TFLOP/s/GPU): 734.8 | learning rate: 1.474559E-08 | global batch size:   128 | lm loss: 1.213362E+01 | loss scale: 1.0 | grad norm: 0.000 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • To reproduce the FSDP2 error motivating this PR, use the broken FusedAdam code before this PR/commit and run this test that saves a checkpoint with 4 GPUs and loads the saved checkpoint with 2 GPUs:
torchrun --nproc-per-node 4 -m pytest tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -s -k "dcp_resharding_save" && torchrun --nproc-per-node 2 -m pytest tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py -v -s -k "dcp_resharding_load"

E               raise ValueError(
E           ValueError: Size mismatch between saved torch.Size([64]) and current: torch.Size([128]) for optimizer.state.0.exp_avg
E           Traceback (most recent call last): (RANK 1)
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/utils.py", line 193, in reduce_scatter
E               local_data = map_fun()
E                            ^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/logger.py", line 90, in wrapper
E               result = func(*args, **kwargs)
E                        ^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 269, in local_step
E               local_plan = planner.create_local_plan()
E                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/default_planner.py", line 352, in create_local_plan
E               return create_default_local_load_plan(
E                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E             File "/usr/local/lib/python3.12/dist-packages/torch/distributed/checkpoint/default_planner.py", line 485, in create_default_local_load_plan
E               raise ValueError(
E           ValueError: Size mismatch between saved torch.Size([64]) and current: torch.Size([128]) for optimizer.state.0.exp_avg
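The size mismatch above follows directly from how DCP records the two tensor types. The arithmetic can be illustrated with a small sketch (not TE or DCP code; dcp_recorded_size is a hypothetical helper and 256 a hypothetical element count chosen to reproduce the 64-vs-128 mismatch in the trace):

```python
def dcp_recorded_size(global_numel, world_size, is_dtensor):
    """How many elements DCP sees for this state entry on one rank.

    A plain Tensor is recorded at its *local* (per-rank) size, so the
    recorded size depends on the save-time topology; a DTensor is recorded
    at its *global* size, so resharding across topologies works.
    Assumes even sharding for simplicity.
    """
    return global_numel if is_dtensor else global_numel // world_size


GLOBAL = 256  # hypothetical exp_avg element count

# Plain Tensor: saved on 4 ranks, loaded on 2 ranks -> 64 vs 128 mismatch,
# matching the ValueError in the traceback above.
assert dcp_recorded_size(GLOBAL, 4, is_dtensor=False) == 64
assert dcp_recorded_size(GLOBAL, 2, is_dtensor=False) == 128

# DTensor: the recorded (global) size is topology-independent.
assert dcp_recorded_size(GLOBAL, 4, is_dtensor=True) == dcp_recorded_size(GLOBAL, 2, is_dtensor=True)
```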

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@cspades cspades marked this pull request as ready for review March 24, 2026 17:42
@greptile-apps

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR correctly resolves the FSDP2/DCP checkpointing bug where FusedAdam optimizer states for DTensor-wrapped parameters were stored as plain local tensors instead of DTensors, causing DCP to treat them as unsharded global tensors and breaking cross-topology checkpoint resharding. The fix ensures optimizer states always mirror the distribution characteristics of their associated parameters throughout the entire lifecycle (_initialize_state, get/set_scaled_state, state_dict, load_state_dict, step). A two-phase DCP resharding test (save with 4 ranks, load with 2 ranks) is added to reproduce and verify the fix.

Confidence Score: 5/5

Safe to merge — the bug is correctly fixed end-to-end and covered by a new two-phase resharding test.

All previously identified issues (FP8 global shape mismatch, isinstance(param) vs isinstance(state) check, DTensor parity guard) are fully addressed. The new code consistently wraps/unwraps DTensor states at every entry point, DTensor.from_local calls correctly use the global shape/stride from the parameter, and the two-phase DCP resharding test validates the exact failure scenario. No P0 or P1 issues remain.

No files require special attention; the core optimizer logic in fused_adam.py has been reviewed thoroughly and the changes are correct.

Important Files Changed

  • transformer_engine/pytorch/optimizers/fused_adam.py: Core bug fix: optimizer states for DTensor params are now initialized, stored, and accessed as DTensors throughout the lifecycle; an FP8 state shape mismatch (param.shape → data.shape) is also corrected.
  • tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py: Adds the two-phase DCP resharding test (save with 4 ranks, load with 2 ranks) to reproduce and verify the DTensor optimizer state bug fix; xfail guards for known-broken recipes are added.
  • tests/pytorch/distributed/test_torch_fsdp2.py: Adds the test_fsdp2_fused_adam_dcp_resharding orchestrator that sequentially invokes the two-phase save/load torchrun subprocesses; excludes resharding tests from the general FSDP2 suite to avoid topology conflicts.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[param is DTensor] -->|_initialize_state| B[Create local data tensor from local_param]
    B -->|dtype == uint8| C[Float8Tensor via quantizer.make_empty data.shape]
    B -->|other dtypes| D[plain Tensor torch.empty_like]
    C --> E[DTensor.from_local wrap state as DTensor]
    D --> E
    E --> F[self.state param state_name = DTensor]
    F -->|get_unscaled_state| G[Unwrap via ._local_tensor]
    G --> H[Return unscaled float32 value]
    F -->|set_scaled_state| I[Unwrap via ._local_tensor]
    I --> J[_apply_scale or copy_ in-place on local tensor]
    F -->|state_dict| K[get_unscaled_state returns local fp32]
    K -->|param is DTensor| L[Re-wrap as DTensor with global shape/stride]
    L --> M[DCP checkpoint has DTensor states]
    M -->|load_state_dict| N[super.load_state_dict then reset self.state]
    N --> O[Unpack DTensor via ._local_tensor]
    O --> P[set_scaled_state re-creates DTensor-wrapped state]
    F -->|step| Q[Assert DTensor parity: param/grad/states]
    Q --> R[Unwrap all states via ._local_tensor]
    R --> S[Pass local tensors to CUDA kernel]


@vthumbe1503

@cspades, could you please elaborate on the downstream error/issue caused? That is, what happens if we load the unsharded optimizer-state tensor as a plain Tensor instead of a DTensor?

@cspades

cspades commented Mar 24, 2026

@cspades, could you please elaborate on the downstream error/issue caused? That is, what happens if we load the unsharded optimizer-state tensor as a plain Tensor instead of a DTensor?

Here is how I understand it (@shjwudp, correct me if I am wrong about the Megatron-FSDP details), as I still need to reproduce the bug and confirm this PR fixes it. I believe a customer reported this bug?

  • Add fused_adam, quantized_model_init, and fsdp2 example #2698 introduced logic in FusedAdam.__init__ such that if the TE model parameters are DTensors, the optimizer states are created as plain Tensors.
    • The reason (per in-line commentary) is that empty_like does not pick up the correct dtype from a DTensor when the local data is a QuantizedTensor. Note that Megatron-FSDP's main weights are FP32, not QuantizedTensor, so our code worked with the original FusedAdam.
  • When Megatron-FSDP (or Megatron-LM's distributed optimizer) performs its first optimizer.step(), Megatron-FSDP exposes FP32 DTensor main weights to the FusedAdam optimizer, and because of the above logic, plain-Tensor optimizer states are constructed from the DTensor main weights.
  • Megatron-FSDP depends on DTensor optimizer states for DCP checkpointing of FusedAdam's state, because we employ uneven sharding. Instead, it now sees plain Tensors, which may break our DCP integration and/or uneven-DTensor metadata.

The fix is to keep the optimizer state in DTensor form whenever the model parameters are DTensors, and to localize or perform in-place operations on the local Tensor for all FusedAdam operations.

@pstjohn pstjohn left a comment


TL;DR: it would fail if you train on 4 ranks and load on 2 ranks; this PR adds a test for that.

(among other issues with mFSDP)

@cspades

cspades commented Mar 25, 2026

/te-ci L1 pytorch

1 similar comment
@vthumbe1503

/te-ci L1 pytorch

@cspades

cspades commented Mar 31, 2026

https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/47362013

Some random error:

ERROR: file or directory not found: /opt/pytorch/lightning-thunder/thunder/tests/test_transformer_engine_executor.py

Rebased and rerunning CI.

cspades and others added 10 commits March 31, 2026 09:00
…sor.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Add Greptile bug-fixes.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
… re-sharding test.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades force-pushed the cye/fused-adam-dcp branch from 1cf3948 to f095376 on March 31, 2026 16:07
@cspades

cspades commented Mar 31, 2026

… as those tests need to be run in sequence.

Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades cspades force-pushed the cye/fused-adam-dcp branch from e3ae4ea to 0164aaa on March 31, 2026 19:59

@cspades cspades changed the title If model parameters are DTensors, optimizer states should also be DTensors. [FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer states should also be DTensors. Mar 31, 2026