Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.
```python
  ]
)
@parameterized.named_parameters(
```
@NuojCheng these tests depend on the hardware type but should work for all topologies with >= 2 devices. wdyt?
Yes. I think we have v6e-4 for CI tests
src/maxtext/layers/moe.py (Outdated)

```python
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
print("shuwen w0 kernel init:", jax.typeof(w0_kernel))
```

```python
w1_bias = maybe_shard_with_name(w1_bias, w1_bias_ns, self.config.shard_mode)
if wo_bias is not None:
  wo_bias = maybe_shard_with_name(wo_bias, wo_bias_ns, self.config.shard_mode)
```
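As a rough illustration of the pattern in the hunk above, the helper applies a sharding constraint only when logical axis names are given and explicit sharding is enabled. This is a hypothetical plain-Python sketch (the real `maybe_shard_with_name` in MaxText applies a JAX sharding constraint; the names and the dict-tagging here are invented for illustration):

```python
# Hypothetical sketch of a "shard only when names are given" helper.
# The real MaxText helper constrains a jax.Array; here we just tag the
# value with its sharding names so the control flow is visible.

def maybe_shard_with_name(value, names, shard_mode):
  """Attach sharding names to `value` when explicit sharding is enabled.

  `names` is a tuple of logical axis names (None entries mean replicated);
  `shard_mode` selects whether the constraint is applied at all.
  """
  if shard_mode != "explicit" or names is None:
    return value  # leave the value unconstrained
  return {"value": value, "sharding": tuple(names)}

# Bias tensors may be absent, so guard before sharding, mirroring the hunk:
wo_bias = [0.0, 0.0]
wo_bias_ns = ("exp", "embed")  # hypothetical logical axis names
if wo_bias is not None:
  wo_bias = maybe_shard_with_name(wo_bias, wo_bias_ns, "explicit")
```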
Just FYI, we did similar things but in a different style in attention_op.py; see maxtext/src/maxtext/layers/attention_op.py, lines 1487 to 1498 at 82ece9d.
```python
def main(argv: Sequence[str]) -> None:
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")  # threefry2x32
```
Will remove; I encountered some errors with unsafe_rbg when running on a VM.
```python
def test_tiny_config_explicit_shardmode_deepseek(self):
  test_tmpdir = os.environ.get("TEST_TMPDIR")  # pylint: disable=unused-variable
  # Tests the Dense Matmul codepath
```
Let's remove the dense matmul test and just keep sparse matmul.
```python
  ]
)

def test_tiny_config_explicit_shardmode_deepseek(self):
```
Should we use a train_compile test instead (for speed)?
```python
    ("fsdp_expert_no_roe", ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]),
    ("fsdp", ["ici_fsdp_parallelism=-1"]),
)
def test_parallelism_configs(self, parallelism_args):
```
I think this can be train_compile instead as well (compile only should be faster than running on real TPUs)
@NuojCheng I think train_compile should work in this case. Was there a reason for using train instead of train_compile in the rest of this test?
I think either one works fine. We can merge this one and migrate all explicit sharding tests to train compile in another PR.
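The parallelism cases discussed above boil down to expanding a table of named CLI overrides into per-test invocations, which is what `@parameterized.named_parameters` does. A hedged sketch of that expansion, independent of MaxText's actual test harness (the config path and the `build_argv` helper are invented for illustration):

```python
# Named parallelism cases, mirroring the flags shown in the diff.
PARALLELISM_CASES = [
    ("fsdp_expert_no_roe",
     ["ici_fsdp_parallelism=-1", "ici_expert_parallelism=2", "use_ring_of_experts=False"]),
    ("fsdp", ["ici_fsdp_parallelism=-1"]),
]

def build_argv(base_config, overrides):
  """Compose the argv a train (or train_compile) invocation would receive.

  Hypothetical helper: the entrypoint name and config path are placeholders.
  """
  return ["train_compile.py", base_config] + list(overrides)

# Expand every named case, as the parameterized decorator would:
argvs = {name: build_argv("configs/deepseek3-test.yml", flags)
         for name, flags in PARALLELISM_CASES}
```

With a compile-only entrypoint, each expanded argv exercises the sharding logic without running real TPU steps, which is the speed argument made in the thread.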
Description
This PR fixes explicit sharding for DeepSeek by specifying the correct sharding for expert weights, and adds tests with deepseek3-test for different parallelisms.
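The fix amounts to giving each expert weight a logical sharding spec that resolves against the device mesh. A minimal, hypothetical sketch of that resolution step (the logical names and mesh axes below are invented for illustration and do not match MaxText's actual tables):

```python
# Hypothetical logical-axis-name table for MoE expert weights; only the
# name -> mesh-axis resolution step is illustrated here.
LOGICAL_TO_MESH = {
    "exp": "expert",   # expert dimension -> expert-parallel mesh axis
    "embed": "fsdp",   # embedding dimension -> fsdp mesh axis
    "mlp": "tensor",   # hidden dimension -> tensor-parallel mesh axis
}

def resolve_spec(logical_names):
  """Map logical axis names to mesh axes; unmapped names mean replicated."""
  return tuple(LOGICAL_TO_MESH.get(n) for n in logical_names)

# A three-axis expert weight resolves each axis independently:
w0_spec = resolve_spec(("exp", "embed", "mlp"))
```

Getting this table wrong for the expert weights is exactly the kind of bug the parallelism tests in this PR are meant to catch.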
Tests
`pytest tests/integration/smoke/train_smoke_test.py::Train -v` passed. Tested the following configs to validate correctness with deepseek3-test:
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.