diff --git a/src/maxtext/common/common_types.py b/src/maxtext/common/common_types.py index 4f5b825c00..0e1ca70f2a 100644 --- a/src/maxtext/common/common_types.py +++ b/src/maxtext/common/common_types.py @@ -106,6 +106,7 @@ class DecoderBlockType(enum.Enum): QWEN2 = "qwen2" QWEN3 = "qwen3" QWEN3_MOE = "qwen3_moe" + QWEN3_CUSTOM_MOE = "qwen3_custom_moe" QWEN3_NEXT = "qwen3_next" GPT3 = "gpt3" GPT_OSS = "gpt_oss" diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 0b335608c4..5eed05cb11 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -160,8 +160,10 @@ base_emb_dim: 2048 base_num_query_heads: 16 base_num_kv_heads: 16 base_mlp_dim: 7168 +dense_init_scale: 1.0 base_num_decoder_layers: 16 head_dim: 128 +attention_output_dim: -1 # Those parameters are only used with global attention for Gemma4. global_head_dim: 0 global_num_kv_heads: 0 @@ -195,6 +197,7 @@ num_experts_per_tok: 1 megablox: true sparse_matmul: true capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default +moe_expert_input_dim: -1 # feature dimension of the tokens entering the MoE expert blocks. load_balance_loss_weight: 0.0 # weight for the load balance loss use_random_routing: false # whether to use random routing for debug/test purpose use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul diff --git a/src/maxtext/configs/models/qwen3-custom-30b-a3b.yml b/src/maxtext/configs/models/qwen3-custom-30b-a3b.yml new file mode 100644 index 0000000000..93e2fb1b3b --- /dev/null +++ b/src/maxtext/configs/models/qwen3-custom-30b-a3b.yml @@ -0,0 +1,42 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Model config for custom Qwen3-30B-A3B + +# Core Architectural Parameters +decoder_block: "qwen3_custom_moe" +base_emb_dim: 2048 +base_mlp_dim: 2048 +base_num_query_heads: 16 +base_num_kv_heads: 2 +base_num_decoder_layers: 48 +head_dim: 256 +mlp_activations: ["silu", "linear"] +vocab_size: 151936 +normalization_layer_epsilon: 1.0e-6 +use_qk_norm: True +attention_output_dim: 768 +moe_expert_input_dim: 768 + +# MoE Specific Parameters +num_experts: 128 +num_experts_per_tok: 8 +base_moe_mlp_dim: 2048 +norm_topk_prob: true + +# RoPE Settings +rope_max_timescale: 10_000_000 + +# General Model Settings +enable_dropout: False diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 315758015a..9a71a27881 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -256,6 +256,7 @@ class ProfilerType(str, Enum): "qwen3-480b-a35b", "qwen3-next-80b-a3b", "qwen3-omni-30b-a3b", + "qwen3-custom-30b-a3b", "gpt3-175b", "gpt3-22b", "gpt3-6b", @@ -267,6 +268,7 @@ class ProfilerType(str, Enum): "olmo3-7b", "olmo3-7b-pt", "olmo3-32b", + "qwen3-custom-moe", ] @@ -447,11 +449,13 @@ class ModelArchitecture(BaseModel): base_num_query_heads: int = Field(16, description="Base number of query heads.") base_num_kv_heads: int = Field(16, description="Base number of key/value heads.") base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.") + dense_init_scale: float = Field(1.0, description="Initialization scale for dense layers") base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.") 
head_dim: int = Field( 128, description="Model query and key head dimension.", ) + attention_output_dim: int = Field(-1, description="Override output dimension for attention block if set to a positive value.") global_head_dim: int = Field( 0, description="Model query and key head dimension for global attention layers.", @@ -646,6 +650,7 @@ class MoEGeneral(BaseModel): num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.") num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.") capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.") + moe_expert_input_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.") load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.") use_custom_sort_vjp: bool = Field( True, @@ -2802,4 +2807,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de else: self.constant_bound_config = [] + if self.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE: + if self.moe_expert_input_dim != self.attention_output_dim: + raise ValueError( + f"For qwen3_custom_moe, moe_expert_input_dim ({self.moe_expert_input_dim}) " + f"must be equal to attention_output_dim ({self.attention_output_dim})" + ) + return self diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 7932961182..b80f9115df 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -55,6 +55,7 @@ olmo3, qwen2, qwen3, + qwen3_custom, simple_layer, ) from maxtext.multimodal import utils as mm_utils @@ -475,6 +476,8 @@ def get_decoder_layers(self): return [qwen3.Qwen3DecoderLayerToLinen] case DecoderBlockType.QWEN3_MOE: return [qwen3.Qwen3MoeDecoderLayerToLinen] + case DecoderBlockType.QWEN3_CUSTOM_MOE: + return [qwen3_custom.Qwen3CustomMoeDecoderLayerToLinen] case 
DecoderBlockType.QWEN3_NEXT: return [qwen3.Qwen3NextScannableBlockToLinen] if self.config.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen] case DecoderBlockType.SIMPLE: @@ -533,6 +536,7 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.QWEN2, DecoderBlockType.QWEN3, DecoderBlockType.QWEN3_MOE, + DecoderBlockType.QWEN3_CUSTOM_MOE, DecoderBlockType.GPT_OSS, DecoderBlockType.SIMPLE, DecoderBlockType.SIMPLE_MLP, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 7cc2227722..51d945cb25 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -349,6 +349,10 @@ def __init__( self.quant = quant self.rngs = rngs + self.moe_expert_input_dim = getattr(self.config, "moe_expert_input_dim", -1) + if self.moe_expert_input_dim <= 0: + self.moe_expert_input_dim = self.config.emb_dim + if self.config.shard_exp_on_fsdp: # special sharding for dsv3 self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp") @@ -374,7 +378,7 @@ def __init__( self._expert_parallelism_name = "expert" self.gate = GateLogit( - in_features_shape=self.config.emb_dim, + in_features_shape=self.moe_expert_input_dim, out_features_shape=self.num_experts, mesh=self.mesh, model_name=self.config.model_name, @@ -400,14 +404,14 @@ def __init__( # During aqt convert state we delete kernel weight from params to save # memory. Instead they are retrieved from the tensors stored in the 'aqt' # collection. 
- self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim)) - self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim)) - self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim)) + self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim)) + self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim)) + self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim)) else: self.wi_0 = nnx.Param( self.kernel_init( self.rngs.params(), - (num_experts, self.config.emb_dim, intermediate_dim), + (num_experts, self.moe_expert_input_dim, intermediate_dim), weight_dtype, kernel_in_axis, kernel_out_axis, @@ -417,7 +421,7 @@ def __init__( self.wi_1 = nnx.Param( self.kernel_init( self.rngs.params(), - (num_experts, self.config.emb_dim, intermediate_dim), + (num_experts, self.moe_expert_input_dim, intermediate_dim), weight_dtype, kernel_in_axis, kernel_out_axis, @@ -427,7 +431,7 @@ def __init__( self.wo = nnx.Param( self.kernel_init( self.rngs.params(), - (self.num_experts, self.intermediate_dim, self.config.emb_dim), + (self.num_experts, self.intermediate_dim, self.moe_expert_input_dim), self.weight_dtype, kernel_in_axis, kernel_out_axis, @@ -439,7 +443,7 @@ def __init__( wi_bias_axes = ("exp", "activation_mlp") wo_bias_axes = ("exp", "activation_embed_moe") wi_bias_shape = (self.num_experts, self.intermediate_dim) - wo_bias_shape = (self.num_experts, self.config.emb_dim) + wo_bias_shape = (self.num_experts, self.moe_expert_input_dim) self.wi_0_bias = nnx.Param( default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype), sharding=wi_bias_axes, @@ -1182,7 +1186,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r # experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs. 
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok) buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok) - output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype) + output_shape = jax.lax.empty((buffer_size, self.moe_expert_input_dim), dtype=x.dtype) x = jax.lax.ragged_all_to_all( x, @@ -1337,7 +1341,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): ) # Sum up the partial outputs across the expert shards. - output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size())) + output = jnp.reshape(output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size())) output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True) else: @@ -1348,7 +1352,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): output_shape = jax.lax.empty( ( original_inputs_first_dim, - self.config.emb_dim // self.get_tensor_parallelism_size(), + self.moe_expert_input_dim // self.get_tensor_parallelism_size(), ), dtype=intermediate_output.dtype, ) @@ -2095,6 +2099,10 @@ def __init__( self.dtype = dtype self.quant = quant self.rngs = rngs + self.moe_model_dim = getattr(self.config, "moe_expert_input_dim", -1) + if self.moe_model_dim <= 0: + self.moe_model_dim = self.config.emb_dim + # NOTE: the name MoeBlock_0 is to ensure reverse compatibility with # existing checkpoints for routed experts.
self.MoeBlock_0 = RoutedMoE( @@ -2116,7 +2124,7 @@ def __init__( ) self.shared_experts = linears.MlpBlock( mesh=self.mesh, - in_features=self.config.emb_dim, + in_features=self.moe_model_dim, intermediate_dim=self.config.shared_experts * shared_expert_mlp_dim, activations=self.config.mlp_activations, intermediate_dropout_rate=self.config.dropout_rate, diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index 6d502d92c4..ea7ee4266b 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -415,7 +415,7 @@ def __init__( self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE( config=self.config, mesh=mesh, - kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, diff --git a/src/maxtext/models/gemma4.py b/src/maxtext/models/gemma4.py index 9e8c145e84..1803ec705c 100644 --- a/src/maxtext/models/gemma4.py +++ b/src/maxtext/models/gemma4.py @@ -70,7 +70,7 @@ def __init__( self.moe_block = moe.RoutedAndSharedMoE( config=config, mesh=mesh, - kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), weight_dtype=config.weight_dtype, dtype=config.dtype, diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 58a0a2db8f..4dfde74dd6 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -121,7 +121,7 @@ def __init__( num_experts=config.num_experts, num_experts_per_tok=config.num_experts_per_tok, mesh=mesh, - kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), 
intermediate_dim=config.mlp_dim, dtype=config.dtype, diff --git a/src/maxtext/models/llama4.py b/src/maxtext/models/llama4.py index c66e80440b..224ae6a3e1 100644 --- a/src/maxtext/models/llama4.py +++ b/src/maxtext/models/llama4.py @@ -403,7 +403,7 @@ def __init__( self.Llama4MoEBlock_0 = RoutedAndSharedMoE( config=config, mesh=self.mesh, - kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), dtype=config.dtype, weight_dtype=config.weight_dtype, diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py index 46441096d5..faf69273c6 100644 --- a/src/maxtext/models/mixtral.py +++ b/src/maxtext/models/mixtral.py @@ -110,7 +110,7 @@ def __init__( num_experts=config.num_experts, num_experts_per_tok=config.num_experts_per_tok, mesh=mesh, - kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), intermediate_dim=config.mlp_dim, dtype=config.dtype, diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bc7d5fdfc1..3eb16c4285 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -785,7 +785,7 @@ def __init__(self, config: Config, mesh: Mesh, quant: None | Quant = None, *, rn num_experts=cfg.num_experts, num_experts_per_tok=cfg.num_experts_per_tok, mesh=self.mesh, - kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=max_initializers.nd_dense_init(cfg.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), intermediate_dim=cfg.moe_mlp_dim, dtype=cfg.dtype, @@ -815,7 +815,7 @@ def __init__(self, config: Config, mesh: Mesh, quant: None | Quant = None, *, rn out_features_shape=1, use_bias=False, # Qwen3-Next shared_expert_gate does not have a bias dtype=cfg.dtype, - 
kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=max_initializers.nd_dense_init(cfg.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), matmul_precision=cfg.matmul_precision, rngs=rngs, @@ -1261,7 +1261,7 @@ def __init__( num_experts=config.num_experts, num_experts_per_tok=config.num_experts_per_tok, mesh=mesh, - kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=max_initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"), kernel_axes=("embed", None), intermediate_dim=config.moe_mlp_dim, # same as config.mlp_dim dtype=config.dtype, @@ -1923,7 +1923,7 @@ def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): in_features=self.config.d_model_for_audio, intermediate_dim=self.config.encoder_ffn_dim_for_audio, activations=("gelu",), # Single GELU activation - kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_init=max_initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), intermediate_dropout_rate=0.0, # No dropout to match AudioMLP dtype=self.config.dtype_mm, weight_dtype=self.config.weight_dtype, @@ -2039,7 +2039,7 @@ def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): use_bias=False, dtype=self.config.dtype_mm, weight_dtype=self.config.weight_dtype, - kernel_init=nd_dense_init(1.0, "fan_in", "normal"), + kernel_init=nd_dense_init(self.config.dense_init_scale, "fan_in", "normal"), matmul_precision=self.config.matmul_precision, rngs=self.rngs, ) @@ -2130,7 +2130,7 @@ def __init__(self, config: Config, *, rngs: nnx.Rngs = None): use_bias=True, dtype=config.dtype_mm, weight_dtype=config.weight_dtype, - kernel_init=nd_dense_init(1.0, "fan_in", "normal"), + kernel_init=nd_dense_init(config.dense_init_scale, "fan_in", "normal"), matmul_precision=config.matmul_precision, rngs=rngs, ) @@ -2141,7 +2141,7 @@ def __init__(self, config: 
Config, *, rngs: nnx.Rngs = None): use_bias=True, dtype=config.dtype_mm, weight_dtype=config.weight_dtype, - kernel_init=nd_dense_init(1.0, "fan_in", "normal"), + kernel_init=nd_dense_init(config.dense_init_scale, "fan_in", "normal"), matmul_precision=config.matmul_precision, rngs=rngs, ) diff --git a/src/maxtext/models/qwen3_custom.py b/src/maxtext/models/qwen3_custom.py new file mode 100644 index 0000000000..c3e40c1a81 --- /dev/null +++ b/src/maxtext/models/qwen3_custom.py @@ -0,0 +1,239 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# limitations under the License. +"""Custom Qwen3 model decoder layer.""" + +from typing import Any + +from jax.sharding import Mesh +import jax.numpy as jnp + +from flax import linen as nn +from flax import nnx +from jax.ad_checkpoint import checkpoint_name + +from maxtext.common.common_types import Config +from maxtext.layers import initializers as max_initializers +from maxtext.layers import moe +from maxtext.layers import nnx_wrappers +from maxtext.layers import quantizations +from maxtext.layers.quantizations import AqtQuantization as Quant +from maxtext.layers.attentions import Attention +from maxtext.layers.linears import DenseGeneral +from maxtext.utils import max_utils +from maxtext.utils.sharding import create_sharding +from maxtext.inference import page_manager +from maxtext.models.qwen3 import AttentionWithNorm + + +class Qwen3CustomAttention(Attention): + """Custom GQA attention that supports sub-dimensional output.""" + + def init_out_w(self, output_dim: int) -> nnx.Module: + """Initializes the output projection.""" + if not self.config.attention_output_dim > 0: 
+ raise ValueError( + "attention_output_dim must be set to a positive integer for CustomAttention." + ) + + in_features = (self.num_query_heads, self.head_dim) + out_kernel_axis = ( + (None, None, None) + if self.config.ici_context_autoregressive_parallelism > 1 + else ("heads", "kv", "embed") + ) + axis = (-2, -1) + + return DenseGeneral( + in_features_shape=in_features, + out_features_shape=self.config.attention_output_dim, + axis=axis, + kernel_init=self.kernel_init, + kernel_axes=out_kernel_axis, # trade speed with memory + dtype=self.dtype, + weight_dtype=self.weight_dtype, + quant=self.quant, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + use_bias=self.use_bias_in_projections, + rngs=self.rngs, + ) + + +class Qwen3CustomMoeDecoderLayer(AttentionWithNorm): + """Qwen3 Transformer decoder layer (Custom MoE).""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant, + rngs: nnx.Rngs, + ): + super().__init__(config, mesh, model_mode, quant, rngs) + + query_pre_attn_scalar = config.head_dim**-0.5 + batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config, model_mode) + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + # Override self_attention with Qwen3CustomAttention + self.self_attention = Qwen3CustomAttention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=quant, + kv_quant=quantizations.configure_kv_quant(config), + 
use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + use_qk_norm=config.use_qk_norm, + query_pre_attn_scalar=query_pre_attn_scalar, + model_mode=model_mode, + use_mrope=config.use_mrope, + mrope_section=config.mrope_section, + rngs=rngs, + ) + + if config.attention_output_dim <= 0 or config.attention_output_dim != config.moe_expert_input_dim: + raise ValueError("attention_output_dim must be positive and equal to moe_expert_input_dim for Qwen3CustomMoeDecoderLayer.") + + self.moe_block = moe.RoutedAndSharedMoE( + config=self.config, + mesh=mesh, + kernel_init=max_initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + rngs=rngs, + ) + + if ( + self.config.attention_output_dim > 0 + and self.config.attention_output_dim != self.config.emb_dim + ): + out_kernel_axis = ( + (None, None) if self.config.ici_context_autoregressive_parallelism > 1 else ("mlp", "embed") + ) + self.layer_up_projection = DenseGeneral( + in_features_shape=self.config.attention_output_dim, + out_features_shape=self.config.emb_dim, + axis=-1, + kernel_init=max_initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"), + kernel_axes=out_kernel_axis, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + quant=quant, + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + use_bias=False, + rngs=rngs, + ) + else: + self.layer_up_projection = None + + self.mlp_intermediate_sharding = create_sharding( + self.mesh, ("activation_batch", "activation_norm_length", "activation_mlp") + ) + self.out_sharding = create_sharding(self.mesh, self.activation_axis_names) + + def apply_attention_with_norm( + self, + inputs: jnp.ndarray, + decoder_segment_ids: None | jnp.ndarray, + decoder_positions: None | jnp.ndarray, + deterministic: bool, + 
model_mode: str, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, + ): + inputs = nn.with_logical_constraint(inputs, self.activation_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + lnx = self.pre_self_attention_layer_norm(inputs) + lnx = nn.with_logical_constraint(lnx, self.activation_axis_names) + attention_lnx, kv_cache = self.self_attention( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + attention_lnx = nn.with_logical_constraint(attention_lnx, self.activation_axis_names) + return inputs, attention_lnx, kv_cache + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids: None | jnp.ndarray, + decoder_positions: None | jnp.ndarray, + deterministic: bool, + model_mode: str, + previous_chunk=None, + page_state: None | page_manager.PageState = None, + slot: None | int = None, + kv_cache: None | jnp.ndarray = None, + attention_metadata: None | dict[str, Any] = None, + ): + if isinstance(inputs, tuple): + inputs = inputs[0] + + inputs, attention_lnx, kv_cache = self.apply_attention_with_norm( + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + + mlp_lnx, load_balance_loss, _ = self.moe_block( + attention_lnx, intermediate_sharding=self.mlp_intermediate_sharding, out_sharding=self.out_sharding + ) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names) + + if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None: + self.sow("intermediates", "moe_lb_loss", load_balance_loss) + + layer_output = mlp_lnx + if self.layer_up_projection is not None: + layer_output = self.layer_up_projection(layer_output) + layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names) + + layer_output 
= inputs + layer_output + hidden_states = self.post_self_attention_layer_norm(layer_output) + hidden_states = nn.with_logical_constraint(hidden_states, self.activation_axis_names) + + if self.config.scan_layers: + return hidden_states, None + else: + return hidden_states, kv_cache + + +Qwen3CustomMoeDecoderLayerToLinen = nnx_wrappers.to_linen_class( + Qwen3CustomMoeDecoderLayer, + base_metadata_fn=max_initializers.variable_to_logically_partitioned, +) diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 3ed0b40c14..7695a4f25b 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -94,6 +94,41 @@ def test_tiny_config_no_scan(self): ] ) + def test_qwen3_custom_moe_config(self): + test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable + train_main( + [ + None, + get_test_config_path(), + "model_name=qwen3-custom-30b-a3b", + # pylint: disable=f-string-without-interpolation + f"base_output_directory={self.base_output_directory}", + "run_name=runner_test", + f"dataset_path={self.dataset_path}", + "base_emb_dim=256", + "attention_output_dim=256", + "moe_expert_input_dim=256", + "base_mlp_dim=256", + "base_moe_mlp_dim=256", + "head_dim=128", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "num_experts=4", # Reduced from 128 + "num_experts_per_tok=2", # Reduced from 8 + "base_num_decoder_layers=2", + "per_device_batch_size=2", + "max_target_length=128", + "dataset_type=synthetic", + "steps=2", + "enable_checkpointing=False", + rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", + "enable_goodput_recording=False", + "enable_checkpoint_cloud_logger=False", + "monitor_goodput=False", + "scan_layers=False", + ] + ) + def test_tiny_config_explicit_shardmode(self): test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable train_main(