From 7b7599330f81aff62f5b7f74d4e8fe2f3e79c599 Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Wed, 8 Apr 2026 18:40:27 +0000 Subject: [PATCH] remove no exp in moe --- src/maxtext/configs/base.yml | 21 +- .../pipeline-large-moe.yml | 9 +- .../custom_mesh_and_rule/pure-fsdp.yml | 3 - src/maxtext/configs/inference/inference.yml | 10 +- src/maxtext/configs/inference/vllm.yml | 10 +- src/maxtext/configs/post_train/rl_mt_jt.yml | 8 +- src/maxtext/layers/moe.py | 243 +++++++----------- .../tpu7x-16/slice_1/input_shardings.json | 4 +- .../tpu7x-16/slice_1/logical_shardings.json | 36 +-- .../tpu7x-16/slice_1/named_shardings.json | 9 +- .../tpu7x-16/slice_4/input_shardings.json | 4 +- .../tpu7x-16/slice_4/logical_shardings.json | 36 +-- .../tpu7x-16/slice_4/named_shardings.json | 9 +- .../v5p-16/slice_1/input_shardings.json | 4 +- .../v5p-16/slice_1/logical_shardings.json | 36 +-- .../v5p-16/slice_1/named_shardings.json | 9 +- .../v5p-16/slice_4/input_shardings.json | 4 +- .../v5p-16/slice_4/logical_shardings.json | 36 +-- .../v5p-16/slice_4/named_shardings.json | 9 +- .../v6e-16/slice_1/input_shardings.json | 4 +- .../v6e-16/slice_1/logical_shardings.json | 36 +-- .../v6e-16/slice_1/named_shardings.json | 9 +- .../v6e-16/slice_4/input_shardings.json | 4 +- .../v6e-16/slice_4/logical_shardings.json | 36 +-- .../v6e-16/slice_4/named_shardings.json | 9 +- .../tpu7x-16/slice_1/input_shardings.json | 4 +- .../tpu7x-16/slice_1/logical_shardings.json | 84 +++--- .../tpu7x-16/slice_4/input_shardings.json | 4 +- .../tpu7x-16/slice_4/logical_shardings.json | 84 +++--- .../v5p-16/slice_1/input_shardings.json | 4 +- .../v5p-16/slice_1/logical_shardings.json | 84 +++--- .../v5p-16/slice_4/input_shardings.json | 4 +- .../v5p-16/slice_4/logical_shardings.json | 84 +++--- .../v6e-16/slice_1/input_shardings.json | 4 +- .../v6e-16/slice_1/logical_shardings.json | 84 +++--- .../v6e-16/slice_4/input_shardings.json | 4 +- .../v6e-16/slice_4/logical_shardings.json | 84 +++--- 37 files changed, 520 insertions(+), 606 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 79171755c5..99d1afe436 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -455,8 +455,7 @@ custom_mesh_and_rule: "" # replace default mesh and logical rule by specifying y mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert']], - ['activation_batch_no_exp_moe', ['data', 'fsdp', 'fsdp_transpose']], + ['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['activation_heads', ['tensor', 'tensor_transpose', 'sequence','tensor_sequence','autoregressive']], @@ -477,6 +476,7 @@ logical_axis_rules: [ ['activation_embed', ['tensor', 'tensor_transpose']], ['activation_embed_moe', ['tensor', 'tensor_transpose']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose']], @@ -490,6 +490,7 @@ logical_axis_rules: [ ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['decode_length', ['sequence']], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -499,18 +500,10 @@ logical_axis_rules: [ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], ['embed', ['fsdp', 'sequence', 'context', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_no_exp', ['fsdp', 'sequence', 'context']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert']], - ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context' , 'expert']], - ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']], - ['embed_moe', ['fsdp', 'sequence', 'context', 'expert']], - ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], - ['embed_no_exp_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], - ['embed_no_exp_moe', ['fsdp', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'tensor_transpose', 'context']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context']], + ['embed_moe', ['fsdp', 'sequence', 'context']], ['embed_tensor_transpose', ['tensor_transpose']], ['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']], ['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index fd71b98f39..305c7b55c9 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -30,8 +30,7 @@ mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert'] data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'expert']], - ['activation_batch_moe', ['data', 'fsdp', 'expert']], - ['activation_batch_no_exp_moe', ['data', 'fsdp']], + ['activation_batch_moe', ['data', 'fsdp']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']], ['activation_heads', ['tensor']], @@ -45,6 +44,7 @@ logical_axis_rules: [ ['activation_embed', ['tensor']], ['activation_embed_moe', ['tensor']], ['activation_mlp', ['tensor']], + ['activation_mlp_moe', ['tensor']], ['activation_kv', ['tensor']], ['activation_kv_batch', ['data', 'fsdp']], ['activation_kv_head_dim', ['tensor']], @@ -52,15 +52,14 @@ logical_axis_rules: [ ['activation_stage', 'stage'], ['activation_exp', ['expert']], ['mlp', ['tensor']], + ['mlp_moe', ['tensor']], ['mlp_no_fsdp', ['tensor']], ['vocab', ['tensor']], ['heads', ['tensor']], ['q_heads', ['tensor']], ['kv_heads', ['tensor']], ['embed', ['fsdp', 'expert']], # remove context from embed sharding - ['embed_moe', ['fsdp', 'expert']], - ['embed_no_exp', ['fsdp']], - ['embed_no_exp_moe', ['fsdp']], + ['embed_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['norm', ['tensor']], diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml index 5d6939bad6..8b35fadff2 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml @@ -19,16 +19,13 @@ data_sharding: [['fsdp']] logical_axis_rules: [ ['activation_batch', ['fsdp']], ['activation_batch_moe', ['fsdp']], - ['activation_batch_no_exp_moe', ['fsdp']], ['activation_embed_and_logits_batch', ['fsdp']], ['activation_embed_and_logits_batch_sequence', ['fsdp']], ['activation_prefill_kv_batch', ['fsdp']], ['activation_kv_batch', ['fsdp']], ['decode_batch', ['fsdp']], ['embed', ['fsdp']], - ['embed_no_exp', ['fsdp']], ['embed_moe', ['fsdp']], - ['embed_no_exp_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['exp_with_fsdp', 'fsdp'], diff --git a/src/maxtext/configs/inference/inference.yml b/src/maxtext/configs/inference/inference.yml index 40bede0b50..55407b3edc 100644 --- a/src/maxtext/configs/inference/inference.yml +++ b/src/maxtext/configs/inference/inference.yml @@ -12,6 +12,7 @@ logical_axis_rules: [ ['activation_norm_length', ['tensor_sequence', 'sequence']], ['activation_embed', ['tensor_transpose']], ['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']], + ['activation_mlp_moe', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_kv', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['activation_prefill_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']], @@ -25,6 +26,7 @@ logical_axis_rules: [ ['decode_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context_autoregressive']], ['decode_length', []], ['mlp', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], + ['mlp_moe', ['fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive']], ['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive','context_autoregressive']], ['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], ['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']], @@ -33,10 +35,10 @@ logical_axis_rules: [ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']], ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['layers', 'stage'], ['kv', []], diff --git a/src/maxtext/configs/inference/vllm.yml b/src/maxtext/configs/inference/vllm.yml index 3664e5ce1a..ffa984296d 100644 --- a/src/maxtext/configs/inference/vllm.yml +++ b/src/maxtext/configs/inference/vllm.yml @@ -31,7 +31,6 @@ mesh_axes: ['data', 'attn_dp', 'model', 'expert', 'attn_dp_expert'] logical_axis_rules: [ ['activation_batch', ['data']], ['activation_batch_moe', []], - ['activation_batch_no_exp_moe', []], ['activation_embed_and_logits_batch', ['data', 'expert']], ['activation_embed_and_logits_batch_sequence', ['data', 'expert']], ['activation_heads', ['model', 'expert']], @@ -45,6 +44,7 @@ logical_axis_rules: [ ['activation_embed', ['model', 'attn_dp']], ['activation_embed_moe', ['model', 'attn_dp']], ['activation_mlp', ['model', 'attn_dp']], + ['activation_mlp_moe', ['model', 'attn_dp']], ['activation_kv', ['model']], ['activation_prefill_kv_batch', ['expert', 'attn_dp_expert']], ['activation_kv_batch', ['data']], @@ -56,8 +56,8 @@ logical_axis_rules: [ ['decode_batch', ['expert', 'attn_dp_expert']], ['decode_length', []], ['mlp', ['model', 'attn_dp']], + ['mlp_moe', ['model', 'attn_dp']], ['mlp_no_fsdp', ['model', 'attn_dp']], - ['moe_mlp', ['model', 'attn_dp']], ['vocab', ['model', 'attn_dp']], ['heads', ['model']], ['q_heads', ['model', 'expert']], @@ -66,11 +66,9 @@ logical_axis_rules: [ ['kv', []], ['embed', ['expert', 'attn_dp_expert']], ['embed', ['attn_dp_expert']], - ['embed_moe', ['expert', 'attn_dp_expert']], - ['embed_moe', ['attn_dp_expert']], + ['embed_moe', []], + ['embed_moe', []], ['embed_tensor_transpose', ['attn_dp', 'model']], - ['embed_no_exp', []], - ['embed_no_exp_moe', []], ['q_lora', ['expert', 'attn_dp_expert']], ['kv_lora', ['expert', 'attn_dp_expert']], ['norm', []], diff --git a/src/maxtext/configs/post_train/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml index c83addc01c..4383b1c4ac 100644 --- a/src/maxtext/configs/post_train/rl_mt_jt.yml +++ b/src/maxtext/configs/post_train/rl_mt_jt.yml @@ -49,10 +49,10 @@ logical_axis_rules: [ ['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']], ['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']], ['embed', ['fsdp', 'sequence', 'expert']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], - ['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], - ['embed_no_exp', ['fsdp', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive', 'tensor_transpose']], + ['embed_moe', ['fsdp', 'fsdp_transpose', 'sequence', 'context_autoregressive']], + ['embed_moe', ['fsdp', 'sequence', 'context_autoregressive']], ['norm', ['tensor', 'tensor_transpose', 'tensor_sequence']], ['layers', 'stage'], ['kv', []], diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 7cc2227722..7eb723ef19 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -102,8 +102,8 @@ def _sort_activations_custom_bwd(residuals: jax.Array, grads: jax.Array) -> tupl def get_batchsplit_init_kernel_axes(): return ( - ("embed_no_exp", "fsdp_transpose_only", "expert_only"), - ("embed_no_exp", "fsdp_transpose_and_expert", None), + ("embed_moe", "fsdp_transpose_only", "expert_only"), + ("embed_moe", "fsdp_transpose_and_expert", None), ) @@ -278,7 +278,7 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. contract_ind = tuple(range(0, len(norm_axis))) output_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_moe", None)) + create_sharding(self.mesh, ("activation_batch", "activation_length", None)) if self.shard_mode == ShardMode.EXPLICIT else None ) @@ -351,16 +351,16 @@ def __init__( if self.config.shard_exp_on_fsdp: # special sharding for dsv3 - self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp") - self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None) + self.wi_kernel_axes = ("embed_moe", None, "mlp_moe") + self.wo_kernel_axes = ("embed_moe", "mlp_moe", None) elif self.config.use_2d_fsdp_sharding: - self.wi_kernel_axes = ("embed_no_exp_moe", "mlp", None) - self.wo_kernel_axes = ("embed_no_exp_moe", "mlp", None) + self.wi_kernel_axes = ("embed_moe", "mlp_moe", None) + self.wo_kernel_axes = ("embed_moe", "mlp_moe", None) elif self.config.use_batch_split_schedule: self.wi_kernel_axes, self.wo_kernel_axes = get_batchsplit_init_kernel_axes() else: - self.wi_kernel_axes = ("exp", "embed_no_exp_moe", "mlp") - self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp_moe") + self.wi_kernel_axes = ("exp", "embed_moe", "mlp_moe") + self.wo_kernel_axes = ("exp", "mlp_moe", "embed_moe") if self.config.attention == "vllm_rpa": # vLLM uses 'model' as the tensor parallelism axis name @@ -437,7 +437,7 @@ def __init__( if self.config.mlp_bias: wi_bias_axes = ("exp", "activation_mlp") - wo_bias_axes = ("exp", "activation_embed_moe") + wo_bias_axes = ("exp", "activation_embed") wi_bias_shape = (self.num_experts, self.intermediate_dim) wo_bias_shape = (self.num_experts, self.config.emb_dim) self.wi_0_bias = nnx.Param( @@ -1020,48 +1020,24 @@ def gmm( output = output[: hs_shape[0]] return output - # Currently, we support data, tensor, and expert parallelism with Megablox. - # We all gather the input activations over tensor parallelism to follow - # https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf. - - # Check if the batch should be sharded by expert and whether the batch_size - # supports this. For example, for interleaved inference, prefill always has - # batch_size=1 while decode can have batch_size > 1. - try: - is_batch_sharded_by_expert = ( - self._expert_parallelism_name - in tuple( - filter( - lambda tup: tup[0] == "activation_batch_moe", - self.config.logical_axis_rules, - ) - )[ - 0 - ][1] - ) - except: # pylint: disable=bare-except - is_batch_sharded_by_expert = False - if is_batch_sharded_by_expert and inputs.shape[0] > 1: - batch_logical_axis = "activation_batch_moe" - else: - batch_logical_axis = "activation_batch_no_exp_moe" + batch_logical_axis = "activation_batch" if self.get_tensor_transpose_parallelism_size() > 1: input_partition_pspec = self._logical_to_mesh_axes( - (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe") + (batch_logical_axis, "activation_norm_length", "activation_embed") ) w0_bias_pspec = self._logical_to_mesh_axes(("exp", None)) w1_bias_pspec = self._logical_to_mesh_axes(("exp", None)) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) else: - input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp")) - wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe")) + wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed")) - gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + gate_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None)) + pre_bias_logits_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None)) else: # pre_bias_logits is None for non-DeepSeek v3 models pre_bias_logits_pspec = None @@ -1113,7 +1089,7 @@ def gmm( P(), # Replicate the input key ), out_specs=( - self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")), + self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed")), P(), # Handle None or replicate the output P(), # Handle None or replicate the output ), @@ -1159,58 +1135,47 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r ) if num_expert_parallelism > 1: - batch_axis = self._expert_parallelism_name if is_batch_sharded_by_expert else "data" # get group sizes for all shards local_expert_size = self.config.num_experts // num_expert_parallelism reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1) global_group_sizes = group_sizes - if is_batch_sharded_by_expert: - all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=batch_axis) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - all_shards_group_sizes, - expert_shard_id, - num_expert_parallelism, - ) - # TODO(ranran): For better performance, we could update output buffer to a smaller - # size to replace self.get_expert_parallelism_size() for efficiency, - # Or we could apply capacity_factor for excessive experts. - # Note: Reducing buffer increase the risk of token dropping under unbalanced distribution. - - # In the worst case, all of the global input data is assigned to each expert in the current shard. - # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if - # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs. - max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) - buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok) - output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype) - - x = jax.lax.ragged_all_to_all( - x, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) - global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name) - x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( - x, - global_group_sizes, - local_expert_size, - shard_index=expert_shard_id, - use_custom_sort_vjp=self.config.use_custom_sort_vjp, - ) - else: - x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( - x, - global_group_sizes[None, :], - local_expert_size, - shard_index=expert_shard_id, - is_offset=True, - global_sorted_experts=selected_experts, - use_custom_sort_vjp=self.config.use_custom_sort_vjp, - ) + all_shards_group_sizes = jax.lax.all_gather(reshaped_group_sizes, axis_name=self._expert_parallelism_name) + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( + all_shards_group_sizes, + expert_shard_id, + num_expert_parallelism, + ) + + # TODO(ranran): For better performance, we could update output buffer to a smaller + # size to replace self.get_expert_parallelism_size() for efficiency, + # Or we could apply capacity_factor for excessive experts. + # Note: Reducing buffer increase the risk of token dropping under unbalanced distribution. + + # In the worst case, all of the global input data is assigned to each expert in the current shard. + # This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if + # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs. + max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) + buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok) + output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype) + + x = jax.lax.ragged_all_to_all( + x, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self._expert_parallelism_name, + ) + global_group_sizes = jax.lax.all_gather(group_sizes, axis_name=self._expert_parallelism_name) + x, local_sorted_indices, group_sizes, selected_experts = RoutedMoE.local_permute( + x, + global_group_sizes, + local_expert_size, + shard_index=expert_shard_id, + use_custom_sort_vjp=self.config.use_custom_sort_vjp, + ) if self.config.mlp_bias: w0_bias, w1_bias, wo_bias = self.transform_bias(selected_experts, w0_bias, w1_bias, wo_bias) @@ -1352,47 +1317,27 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): ), dtype=intermediate_output.dtype, ) - if is_batch_sharded_by_expert: - # locally unpermute back to the original order - local_output = _sort_activations( - intermediate_output, - jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable - self.config.use_custom_sort_vjp, - ) - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable - expert_shard_id, - num_expert_parallelism, - ) - intermediate_output = jax.lax.ragged_all_to_all( - local_output, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) - else: - # If bach is replicated across EP shards then each shard should send - # 0..local_shard_size data to the other shards and receive the - # local_shard data from all of the other shards using - # ragged_all_to_all. - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - reshaped_group_sizes, # pylint: disable=undefined-variable - expert_shard_id, - num_expert_parallelism, - is_batch_sharded=False, - ) - intermediate_output = jax.lax.ragged_all_to_all( - intermediate_output, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) + + # locally unpermute back to the original order + local_output = _sort_activations( + intermediate_output, + jnp.argsort(local_sorted_indices), # pylint: disable=undefined-variable + self.config.use_custom_sort_vjp, + ) + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( + jnp.transpose(all_shards_group_sizes), # pylint: disable=undefined-variable + expert_shard_id, + num_expert_parallelism, + ) + intermediate_output = jax.lax.ragged_all_to_all( + local_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self._expert_parallelism_name, + ) output = self.unpermute( intermediate_output, @@ -1425,13 +1370,13 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): wo_kernel = self._maybe_shard_with_logical(wo_kernel, ("exp_with_fsdp", "mlp_no_fsdp", "embed_tensor_transpose")) if self.get_tensor_transpose_parallelism_size() > 1: - input_axes = (batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe") + input_axes = (batch_logical_axis, "activation_norm_length", "activation_embed") else: - input_axes = (batch_logical_axis, "activation_norm_length_moe", None) + input_axes = (batch_logical_axis, "activation_norm_length", None) - gate_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None) + gate_logits_axes = (batch_logical_axis, "activation_norm_length", None) if self.config.model_name.startswith("deepseek3"): - pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length_moe", None) + pre_bias_logits_axes = (batch_logical_axis, "activation_norm_length", None) else: pre_bias_logits_axes = None @@ -1449,14 +1394,12 @@ def reshape_and_update_weights(self, weights, indices): # output of updated weights: (batch_size, seq_len, num_experts) update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype) index_update = ( - self._maybe_shard_with_logical( - jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None) - ), - self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_moe", None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[0])[:, None, None], ("activation_batch", None, None)), + self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length", None)), indices, ) weight_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_moe", None)) + create_sharding(self.mesh, ("activation_batch", "activation_length", None)) if self.config.shard_mode == ShardMode.EXPLICIT else None ) @@ -1511,7 +1454,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): expert_mask, (batch_size, cp, sub_seq * self.num_experts_per_tok, self.num_experts), ) - expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch_moe", None, None, None)) + expert_mask_fused = self._maybe_shard_with_logical(expert_mask_fused, ("activation_batch", None, None, None)) expert_token_count_fused = jnp.cumsum(expert_mask_fused, axis=2) expert_token_count = jnp.reshape( expert_token_count_fused, @@ -1519,7 +1462,7 @@ def generate_masks_subgroup(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch_moe", "activation_norm_length_moe", None, None, None), + ("activation_batch", "activation_norm_length", None, None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=3) @@ -1607,7 +1550,7 @@ def generate_masks(self, top_k_indices, softmax_probs): ) expert_token_count = self._maybe_shard_with_logical( expert_token_count, - ("activation_batch_moe", "activation_norm_length_moe", None, None), + ("activation_batch", "activation_norm_length", None, None), ) trunc_expert_mask = expert_mask * jnp.less_equal(expert_token_count, expert_capacity_per_batch) combined_expert_mask = jnp.sum(trunc_expert_mask, axis=2) @@ -1752,13 +1695,13 @@ def dense_matmul( mask_axes = ("activation_batch_moe", "activation_norm_length_moe", None, None) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, "activation_mlp", ) @@ -1787,14 +1730,14 @@ def dense_matmul( ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_mlp", @@ -1815,14 +1758,14 @@ def dense_matmul( ) dispatch_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_embed_moe", ) mlp_axis = ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, None, "activation_mlp", @@ -1848,7 +1791,7 @@ def dense_matmul( dispatch, ( None, - "activation_batch_no_exp_moe", + "activation_batch_moe", "activation_norm_length_moe", None, "activation_embed_moe", @@ -1911,7 +1854,7 @@ def dense_matmul( intermediate_layer, ( "activation_exp", - "activation_batch_no_exp_moe", + "activation_batch_moe", None, "activation_embed_moe", ), diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json index 0ced56c01f..2c4505db84 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,64]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json index 5e33fc22b8..5cfea0ee37 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json @@ -157,8 +157,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -171,8 +171,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -185,8 +185,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -483,8 +483,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -497,8 +497,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -511,8 +511,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -805,8 +805,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -819,8 +819,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -833,8 +833,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json index 1fd6ceb6fd..2c9e622ff4 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json @@ -692,8 +692,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -2140,8 +2139,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -3552,8 +3550,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json index 950b64159c..0d4e78849a 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,64]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json index 5e33fc22b8..5cfea0ee37 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json @@ -157,8 +157,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -171,8 +171,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -185,8 +185,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -483,8 +483,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -497,8 +497,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -511,8 +511,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -805,8 +805,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -819,8 +819,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -833,8 +833,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json index 5b2ab94daf..c0112a9795 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json @@ -692,8 +692,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -2140,8 +2139,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -3552,8 +3550,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json index f91f7f18a5..a4ec8341ba 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[96,2048,2048]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[96,2048,64]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json index 5e33fc22b8..5cfea0ee37 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json @@ -157,8 +157,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -171,8 +171,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -185,8 +185,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -483,8 +483,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -497,8 +497,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -511,8 +511,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -805,8 +805,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -819,8 +819,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -833,8 +833,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json index 72cbbdea66..521ab4bfb8 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json @@ -692,8 +692,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -2140,8 +2139,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -3552,8 +3550,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json index 4f601415fb..01d318e884 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[384,2048,2048]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[384,2048,64]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json index 5e33fc22b8..5cfea0ee37 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json @@ -157,8 +157,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -171,8 +171,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -185,8 +185,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -483,8 +483,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -497,8 +497,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -511,8 +511,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -805,8 +805,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -819,8 +819,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -833,8 +833,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json index 65120bac91..390ea29c8e 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json @@ -692,8 +692,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -2140,8 +2139,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -3552,8 +3550,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json index 0ced56c01f..2c4505db84 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[192,2048,2048]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,64]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json index 5e33fc22b8..5cfea0ee37 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json @@ -157,8 +157,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -171,8 +171,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -185,8 +185,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -483,8 +483,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -497,8 +497,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -511,8 +511,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -805,8 +805,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -819,8 +819,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -833,8 +833,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json index 1fd6ceb6fd..2c9e622ff4 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json @@ -692,8 +692,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -2140,8 +2139,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -3552,8 +3550,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json index 950b64159c..0d4e78849a 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json @@ -122,13 +122,13 @@ }, { "moe/inputs: bfloat16[768,2048,2048]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,64]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json index 5e33fc22b8..5cfea0ee37 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json @@ -157,8 +157,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -171,8 +171,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -185,8 +185,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -483,8 +483,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -497,8 +497,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -511,8 +511,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, @@ -805,8 +805,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -819,8 +819,8 @@ "partition_spec": [ "exp", "moe_layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 64, @@ -833,8 +833,8 @@ "partition_spec": [ "exp", "moe_layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 64, diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json index 5b2ab94daf..c0112a9795 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json @@ -692,8 +692,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -2140,8 +2139,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null @@ -3552,8 +3550,7 @@ "fsdp_transpose", "sequence", "tensor_transpose", - "context", - "expert" + "context" ], null, null diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json index 0e1840a5a8..82d1e6f2ee 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[192,2048,2880]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,32]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json index 35b79ae83c..c944c8e273 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json @@ -149,8 +149,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -175,8 +175,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -201,8 +201,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -381,8 +381,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -407,8 +407,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -433,8 +433,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -645,8 +645,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -671,8 +671,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -697,8 +697,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -877,8 +877,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -903,8 +903,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -929,8 +929,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1137,8 +1137,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1163,8 +1163,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1189,8 +1189,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1369,8 +1369,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1395,8 +1395,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1421,8 +1421,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json index 81635e5b05..86c3a3071e 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[768,2048,2880]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,32]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json index 35b79ae83c..c944c8e273 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json @@ -149,8 +149,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -175,8 +175,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -201,8 +201,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -381,8 +381,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -407,8 +407,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -433,8 +433,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -645,8 +645,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -671,8 +671,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -697,8 +697,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -877,8 +877,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -903,8 +903,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -929,8 +929,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1137,8 +1137,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1163,8 +1163,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1189,8 +1189,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1369,8 +1369,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1395,8 +1395,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1421,8 +1421,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json index 5012359e63..40d9a30c0d 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[96,2048,2880]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[96,2048,32]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json index 35b79ae83c..c944c8e273 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json @@ -149,8 +149,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -175,8 +175,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -201,8 +201,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -381,8 +381,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -407,8 +407,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -433,8 +433,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -645,8 +645,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -671,8 +671,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -697,8 +697,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -877,8 +877,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -903,8 +903,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -929,8 +929,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1137,8 +1137,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1163,8 +1163,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1189,8 +1189,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1369,8 +1369,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1395,8 +1395,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1421,8 +1421,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json index 8410b8088d..e6964d7ed0 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[384,2048,2880]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[384,2048,32]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json index 35b79ae83c..c944c8e273 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json @@ -149,8 +149,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -175,8 +175,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -201,8 +201,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -381,8 +381,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -407,8 +407,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -433,8 +433,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -645,8 +645,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -671,8 +671,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -697,8 +697,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -877,8 +877,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -903,8 +903,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -929,8 +929,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1137,8 +1137,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1163,8 +1163,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1189,8 +1189,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1369,8 +1369,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1395,8 +1395,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1421,8 +1421,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json index 0e1840a5a8..82d1e6f2ee 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[192,2048,2880]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } }, { "moe/gate_logits: bfloat16[192,2048,32]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P('fsdp', None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json index 35b79ae83c..c944c8e273 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json @@ -149,8 +149,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -175,8 +175,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -201,8 +201,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -381,8 +381,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -407,8 +407,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -433,8 +433,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -645,8 +645,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -671,8 +671,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -697,8 +697,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -877,8 +877,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -903,8 +903,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -929,8 +929,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1137,8 +1137,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1163,8 +1163,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1189,8 +1189,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1369,8 +1369,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1395,8 +1395,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1421,8 +1421,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json index 81635e5b05..86c3a3071e 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json @@ -56,13 +56,13 @@ }, { "moe/inputs: bfloat16[768,2048,2880]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } }, { "moe/gate_logits: bfloat16[768,2048,32]": { - "logic_axes": "('activation_batch_moe', 'activation_norm_length_moe', None)", + "logic_axes": "('activation_batch', 'activation_norm_length', None)", "PartitionSpec": "P(('data', 'fsdp'), None, None)" } } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json index 35b79ae83c..c944c8e273 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json @@ -149,8 +149,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -175,8 +175,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -201,8 +201,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -215,7 +215,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -381,8 +381,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -407,8 +407,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -433,8 +433,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -447,7 +447,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -645,8 +645,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -671,8 +671,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -697,8 +697,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -711,7 +711,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -877,8 +877,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -903,8 +903,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -929,8 +929,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -943,7 +943,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1137,8 +1137,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1163,8 +1163,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1189,8 +1189,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1203,7 +1203,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32, @@ -1369,8 +1369,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1395,8 +1395,8 @@ "partition_spec": [ "exp", "layers", - "embed_no_exp_moe", - "mlp" + "embed_moe", + "mlp_moe" ], "shape": [ 32, @@ -1421,8 +1421,8 @@ "partition_spec": [ "exp", "layers", - "mlp", - "embed_no_exp_moe" + "mlp_moe", + "embed_moe" ], "shape": [ 32, @@ -1435,7 +1435,7 @@ "partition_spec": [ "exp", "layers", - "activation_embed_moe" + "activation_embed" ], "shape": [ 32,