[FSDP2/Megatron-FSDP/DCP] If model parameters are DTensors, optimizer states should also be DTensors. #2795
cspades wants to merge 12 commits into NVIDIA:main from
Conversation
Greptile Summary: This PR correctly resolves the FSDP2/DCP checkpointing bug.

Confidence Score: 5/5 — Safe to merge. The bug is correctly fixed end-to-end and covered by a new two-phase resharding test. All previously identified issues (FP8 global shape mismatch, `isinstance(param)` vs `isinstance(state)` check, DTensor parity guard) are fully addressed. The new code consistently wraps/unwraps DTensor states at every entry point, `DTensor.from_local` calls correctly use the global shape/stride from the parameter, and the two-phase DCP resharding test validates the exact failure scenario. No P0 or P1 issues remain. No files require special attention; the core optimizer logic in fused_adam.py has been reviewed thoroughly and the changes are correct.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[param is DTensor] -->|_initialize_state| B[Create local data tensor from local_param]
    B -->|dtype == uint8| C[Float8Tensor via quantizer.make_empty data.shape]
    B -->|other dtypes| D[plain Tensor torch.empty_like]
    C --> E[DTensor.from_local wrap state as DTensor]
    D --> E
    E --> F[self.state param state_name = DTensor]
    F -->|get_unscaled_state| G[Unwrap via ._local_tensor]
    G --> H[Return unscaled float32 value]
    F -->|set_scaled_state| I[Unwrap via ._local_tensor]
    I --> J[_apply_scale or copy_ in-place on local tensor]
    F -->|state_dict| K[get_unscaled_state returns local fp32]
    K -->|param is DTensor| L[Re-wrap as DTensor with global shape/stride]
    L --> M[DCP checkpoint has DTensor states]
    M -->|load_state_dict| N[super.load_state_dict then reset self.state]
    N --> O[Unpack DTensor via ._local_tensor]
    O --> P[set_scaled_state re-creates DTensor-wrapped state]
    F -->|step| Q[Assert DTensor parity: param/grad/states]
    Q --> R[Unwrap all states via ._local_tensor]
    R --> S[Pass local tensors to CUDA kernel]
```
@cspades could you please elaborate on the downstream error/issue caused? That is, what happens if we load the unsharded tensor for the optimizer state as a plain tensor instead of a DTensor?
Here is how I understand it — @shjwudp, correct me if I am wrong about the Megatron-FSDP details, as I still need to reproduce the bug and ensure this PR fixes it. I believe a customer reported this bug?
The fix is to keep the optimizer state in DTensor form.
pstjohn left a comment
TL;DR: it would fail if you train on 4 ranks and load on 2 ranks; this PR adds a test for that (among other issues with mFSDP).
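The save-on-4-ranks / load-on-2-ranks failure can be illustrated with some hypothetical numbers (pure arithmetic, not the actual test): if the optimizer state is unwrapped to a plain tensor before saving, DCP records each rank's local shard as if it were the full global tensor, so a load at a different rank count cannot reshard it.

```python
# Illustrative only; all numbers are hypothetical.
global_numel = 1024
save_ranks, load_ranks = 4, 2

# With proper DTensor metadata, DCP records the global shape (1024) plus
# each rank's shard, so any rank count can reassemble the state.
shard_at_save = global_numel // save_ranks  # 256 elements per saving rank
shard_at_load = global_numel // load_ranks  # 512 elements per loading rank

# The bug: the state was a plain tensor, so DCP saved the 256-element local
# shard as if it were the entire global tensor.
saved_global_shape = shard_at_save  # 256, should have been 1024

# A 2-rank load expects to carve 512-element shards out of a 1024-element
# global tensor, which the mis-recorded checkpoint cannot provide.
assert saved_global_shape != global_numel
```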
/te-ci L1 pytorch

/te-ci L1 pytorch
https://gitlab-master.nvidia.com/dl/transformerengine/transformerengine/-/pipelines/47362013 — some random error. Rebased and rerunning CI.
…sor. Signed-off-by: Cory Ye <cye@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Cory Ye <cye@nvidia.com>
Add Greptile bug-fixes. Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Cory Ye <44509866+cspades@users.noreply.github.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
… re-sharding test. Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
… as those tests need to be run in sequence. Signed-off-by: Cory Ye <cye@nvidia.com>
Description

Fixes the DTensor(QuantizedTensor) (FSDP2-only) use case introduced in "Add fused_adam, quantized_model_init, and fsdp2 example" #2698, where FusedAdam's optimizer state is converted into a non-distributed Tensor, which is then loaded as a global / un-sharded state dictionary by Torch DCP. This breaks Megatron-FSDP checkpointing with DCP in MLM, and causes FSDP2 to not have a distributed optimizer state. (See re-shard tests.)
Details

`dcp_resharding_save` must be run with, and before, `dcp_resharding_load`. `dcp_resharding_save` deletes any existing checkpoint directory and writes a new DCP checkpoint, while `dcp_resharding_load` finally deletes the saved DCP checkpoint as well.

Testing

`--use-precision-aware-optimizer`

Type of change
Checklist: