Skip to content

Fix explicit sharding for Deepseek#3595

Open
Shuwen-Fang wants to merge 7 commits intomainfrom
explicitpp
Open

Fix explicit sharding for Deepseek#3595
Shuwen-Fang wants to merge 7 commits intomainfrom
explicitpp

Conversation

@Shuwen-Fang
Copy link
Copy Markdown
Collaborator

@Shuwen-Fang Shuwen-Fang commented Apr 7, 2026

Description

This PR fixes explicit sharding for deepseek by specifying correct sharding for expert weights, and add tests with ds3-test for different parallelisms.

Tests

  • Verified pytest tests/integration/smoke/train_smoke_test.py::Train -v passed
  • Added additional test coverage for explicit sharding with deepseek

Tested the following configs to validate correctness with deepseek3-test:

ici_fsdp_parallelism=2 \
ici_expert_parallelism=8 \
use_ring_of_experts=false \
ici_fsdp_parallelism=2 \
ici_expert_parallelism=8 \
use_ring_of_experts=false \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=4 \
ici_fsdp_parallelism=16 \
ici_fsdp_parallelism=8\
ici_tensor_transpose_parallelism=2 \

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov bot commented Apr 7, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@Shuwen-Fang Shuwen-Fang changed the title temp Fix explicit sharding for moe models Apr 7, 2026
@Shuwen-Fang Shuwen-Fang self-assigned this Apr 8, 2026
@Shuwen-Fang Shuwen-Fang changed the title Fix explicit sharding for moe models Fix explicit sharding for Deepseek Apr 8, 2026
]
)

@parameterized.named_parameters(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NuojCheng these tests depend on the hardware type but should work for all topologies with >= 2 devices. wdyt?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I think we have v6e-4 for CI tests


w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
print("shuwen w0 kernel init:", jax.typeof(w0_kernel))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remember to remove

w1_bias = maybe_shard_with_name(w1_bias, w1_bias_ns, self.config.shard_mode)
if wo_bias is not None:
wo_bias = maybe_shard_with_name(wo_bias, wo_bias_ns, self.config.shard_mode)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI we did similar things but in a different style in attention_op.py, see

def _maybe_shard_with_pspec(inputs, pspec: jax.sharding.PartitionSpec | None):
# decoder_segment_ids can be None
if pspec is None:
return None
sharding = NamedSharding(self.mesh, pspec)
return maybe_shard_with_name(
inputs,
sharding,
shard_mode=self.config.shard_mode,
debug_sharding=self.config.debug_sharding,
extra_stack_level=1,
)
. It is optional to make the change

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated


def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
jax.config.update("jax_default_prng_impl", "unsafe_rbg") #threefry2x32
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will remove, i encountered some errors with unsafe_rbg when running on vm


def test_tiny_config_explicit_shardmode_deepseek(self):
test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
# Tests the Dense Matmul codepath
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove dense matmul test, just keep sparse matmul

@Shuwen-Fang Shuwen-Fang requested a review from NuojCheng April 8, 2026 17:35
]
)

def test_tiny_config_explicit_shardmode_deepseek(self):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we use train_compile test instead (for speed?)

("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]),
("fsdp", ["ici_fsdp_parallelism=-1"]),
)
def test_parallelism_configs(self, parallelism_args):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be train_compile instead as well (compile only should be faster than running on real TPUs)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NuojCheng I think train compile should work in this case. Was there a reason for using train instead of train compile in the rest of this test?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think either one works fine. We can merge this one and migrate all explicit sharding tests to train compile in another PR.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants