Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/common/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class DecoderBlockType(enum.Enum):
QWEN2 = "qwen2"
QWEN3 = "qwen3"
QWEN3_MOE = "qwen3_moe"
QWEN3_CUSTOM_MOE = "qwen3_custom_moe"
QWEN3_NEXT = "qwen3_next"
GPT3 = "gpt3"
GPT_OSS = "gpt_oss"
Expand Down
3 changes: 3 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ base_emb_dim: 2048
base_num_query_heads: 16
base_num_kv_heads: 16
base_mlp_dim: 7168
dense_init_scale: 1.0
base_num_decoder_layers: 16
head_dim: 128
attention_output_dim: -1
# Those parameters are only used with global attention for Gemma4.
global_head_dim: 0
global_num_kv_heads: 0
Expand Down Expand Up @@ -195,6 +197,7 @@ num_experts_per_tok: 1
megablox: true
sparse_matmul: true
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
moe_expert_input_dim: -1 # feature dimension of the tokens entering the MoE expert blocks.
load_balance_loss_weight: 0.0 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
Expand Down
42 changes: 42 additions & 0 deletions src/maxtext/configs/models/qwen3-custom-30b-a3b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model config for custom Qwen3-30B-A3B
# ("30B-A3B" follows the Qwen naming convention: ~30B total parameters,
# with a smaller active-parameter subset per token via MoE routing).

# Core Architectural Parameters
decoder_block: "qwen3_custom_moe"
base_emb_dim: 2048
base_mlp_dim: 2048
base_num_query_heads: 16
base_num_kv_heads: 2  # grouped-query attention: 16 query heads share 2 KV heads
base_num_decoder_layers: 48
head_dim: 256
mlp_activations: ["silu", "linear"]  # gated (SwiGLU-style) MLP activation pair
vocab_size: 151936
normalization_layer_epsilon: 1.0e-6
use_qk_norm: True
# NOTE: for decoder_block "qwen3_custom_moe" the config validator requires
# attention_output_dim == moe_expert_input_dim; keep these two in sync.
attention_output_dim: 768
moe_expert_input_dim: 768

# MoE Specific Parameters
num_experts: 128
num_experts_per_tok: 8  # top-k routing: each token is sent to 8 of the 128 experts
base_moe_mlp_dim: 2048
norm_topk_prob: true  # presumably renormalizes top-k router probabilities to sum to 1 — confirm against router impl

# RoPE Settings
rope_max_timescale: 10_000_000  # RoPE base frequency (long-context setting)

# General Model Settings
enable_dropout: False
12 changes: 12 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ class ProfilerType(str, Enum):
"qwen3-480b-a35b",
"qwen3-next-80b-a3b",
"qwen3-omni-30b-a3b",
"qwen3-custom-30b-a3b",
"gpt3-175b",
"gpt3-22b",
"gpt3-6b",
Expand All @@ -267,6 +268,7 @@ class ProfilerType(str, Enum):
"olmo3-7b",
"olmo3-7b-pt",
"olmo3-32b",
"qwen3-custom-moe",
]


Expand Down Expand Up @@ -447,11 +449,13 @@ class ModelArchitecture(BaseModel):
base_num_query_heads: int = Field(16, description="Base number of query heads.")
base_num_kv_heads: int = Field(16, description="Base number of key/value heads.")
base_mlp_dim: int = Field(7168, description="Base dimension of the MLP layer.")
dense_init_scale: float = Field(1.0, description="Initialization scale for dense layers")
base_num_decoder_layers: int = Field(16, description="Base number of decoder layers.")
head_dim: int = Field(
128,
description="Model query and key head dimension.",
)
attention_output_dim: int = Field(-1, description="Override output dimension for attention block if set to a positive value.")
global_head_dim: int = Field(
0,
description="Model query and key head dimension for global attention layers.",
Expand Down Expand Up @@ -646,6 +650,7 @@ class MoEGeneral(BaseModel):
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
moe_expert_input_dim: int = Field(-1, description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.")
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
use_custom_sort_vjp: bool = Field(
True,
Expand Down Expand Up @@ -2802,4 +2807,11 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
else:
self.constant_bound_config = []

if self.decoder_block == DecoderBlockType.QWEN3_CUSTOM_MOE:
if self.moe_expert_input_dim != self.attention_output_dim:
raise ValueError(
f"For qwen3_custom_moe, moe_expert_input_dim ({self.moe_expert_input_dim}) "
f"must be equal to attention_output_dim ({self.attention_output_dim})"
)

return self
4 changes: 4 additions & 0 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
olmo3,
qwen2,
qwen3,
qwen3_custom,
simple_layer,
)
from maxtext.multimodal import utils as mm_utils
Expand Down Expand Up @@ -475,6 +476,8 @@ def get_decoder_layers(self):
return [qwen3.Qwen3DecoderLayerToLinen]
case DecoderBlockType.QWEN3_MOE:
return [qwen3.Qwen3MoeDecoderLayerToLinen]
case DecoderBlockType.QWEN3_CUSTOM_MOE:
return [qwen3_custom.Qwen3CustomMoeDecoderLayerToLinen]
case DecoderBlockType.QWEN3_NEXT:
return [qwen3.Qwen3NextScannableBlockToLinen] if self.config.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen]
case DecoderBlockType.SIMPLE:
Expand Down Expand Up @@ -533,6 +536,7 @@ def get_norm_layer(self, num_features: int):
DecoderBlockType.QWEN2,
DecoderBlockType.QWEN3,
DecoderBlockType.QWEN3_MOE,
DecoderBlockType.QWEN3_CUSTOM_MOE,
DecoderBlockType.GPT_OSS,
DecoderBlockType.SIMPLE,
DecoderBlockType.SIMPLE_MLP,
Expand Down
32 changes: 20 additions & 12 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ def __init__(
self.quant = quant
self.rngs = rngs

self.moe_expert_input_dim = getattr(self.config, "moe_expert_input_dim", -1)
if self.moe_expert_input_dim <= 0:
self.moe_expert_input_dim = self.config.emb_dim

if self.config.shard_exp_on_fsdp:
# special sharding for dsv3
self.wi_kernel_axes = ("embed_no_exp_moe", None, "mlp")
Expand All @@ -374,7 +378,7 @@ def __init__(
self._expert_parallelism_name = "expert"

self.gate = GateLogit(
in_features_shape=self.config.emb_dim,
in_features_shape=self.moe_expert_input_dim,
out_features_shape=self.num_experts,
mesh=self.mesh,
model_name=self.config.model_name,
Expand All @@ -400,14 +404,14 @@ def __init__(
# During aqt convert state we delete kernel weight from params to save
# memory. Instead they are retrieved from the tensors stored in the 'aqt'
# collection.
self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
else:
self.wi_0 = nnx.Param(
self.kernel_init(
self.rngs.params(),
(num_experts, self.config.emb_dim, intermediate_dim),
(num_experts, self.moe_expert_input_dim, intermediate_dim),
weight_dtype,
kernel_in_axis,
kernel_out_axis,
Expand All @@ -417,7 +421,7 @@ def __init__(
self.wi_1 = nnx.Param(
self.kernel_init(
self.rngs.params(),
(num_experts, self.config.emb_dim, intermediate_dim),
(num_experts, self.moe_expert_input_dim, intermediate_dim),
weight_dtype,
kernel_in_axis,
kernel_out_axis,
Expand All @@ -427,7 +431,7 @@ def __init__(
self.wo = nnx.Param(
self.kernel_init(
self.rngs.params(),
(self.num_experts, self.intermediate_dim, self.config.emb_dim),
(self.num_experts, self.intermediate_dim, self.moe_expert_input_dim),
self.weight_dtype,
kernel_in_axis,
kernel_out_axis,
Expand All @@ -439,7 +443,7 @@ def __init__(
wi_bias_axes = ("exp", "activation_mlp")
wo_bias_axes = ("exp", "activation_embed_moe")
wi_bias_shape = (self.num_experts, self.intermediate_dim)
wo_bias_shape = (self.num_experts, self.config.emb_dim)
wo_bias_shape = (self.num_experts, self.moe_expert_input_dim)
self.wi_0_bias = nnx.Param(
default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
sharding=wi_bias_axes,
Expand Down Expand Up @@ -1182,7 +1186,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)
output_shape = jax.lax.empty((buffer_size, self.moe_expert_input_dim), dtype=x.dtype)

x = jax.lax.ragged_all_to_all(
x,
Expand Down Expand Up @@ -1337,7 +1341,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
)

# Sum up the partial outputs across the expert shards.
output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
output = jnp.reshape(output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size()))
output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)

else:
Expand All @@ -1348,7 +1352,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
output_shape = jax.lax.empty(
(
original_inputs_first_dim,
self.config.emb_dim // self.get_tensor_parallelism_size(),
self.moe_expert_input_dim // self.get_tensor_parallelism_size(),
),
dtype=intermediate_output.dtype,
)
Expand Down Expand Up @@ -2095,6 +2099,10 @@ def __init__(
self.dtype = dtype
self.quant = quant
self.rngs = rngs
self.moe_model_dim = getattr(self.config, "moe_expert_input_dim", -1)
if self.moe_model_dim <= 0:
self.moe_model_dim = self.config.emb_dim

# NOTE: the name MoeBlock_0 is to ensure reverse compatibility with
# existing checkpoints for routed experts.
self.MoeBlock_0 = RoutedMoE(
Expand All @@ -2116,7 +2124,7 @@ def __init__(
)
self.shared_experts = linears.MlpBlock(
mesh=self.mesh,
in_features=self.config.emb_dim,
in_features=self.moe_model_dim,
intermediate_dim=self.config.shared_experts * shared_expert_mlp_dim,
activations=self.config.mlp_activations,
intermediate_dropout_rate=self.config.dropout_rate,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def __init__(
self.DeepSeekMoeBlock_0 = moe.RoutedAndSharedMoE(
config=self.config,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self.moe_block = moe.RoutedAndSharedMoE(
config=config,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
weight_dtype=config.weight_dtype,
dtype=config.dtype,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(
num_experts=config.num_experts,
num_experts_per_tok=config.num_experts_per_tok,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
intermediate_dim=config.mlp_dim,
dtype=config.dtype,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def __init__(
self.Llama4MoEBlock_0 = RoutedAndSharedMoE(
config=config,
mesh=self.mesh,
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
dtype=config.dtype,
weight_dtype=config.weight_dtype,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
num_experts=config.num_experts,
num_experts_per_tok=config.num_experts_per_tok,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
intermediate_dim=config.mlp_dim,
dtype=config.dtype,
Expand Down
14 changes: 7 additions & 7 deletions src/maxtext/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ def __init__(self, config: Config, mesh: Mesh, quant: None | Quant = None, *, rn
num_experts=cfg.num_experts,
num_experts_per_tok=cfg.num_experts_per_tok,
mesh=self.mesh,
kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=max_initializers.nd_dense_init(cfg.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
intermediate_dim=cfg.moe_mlp_dim,
dtype=cfg.dtype,
Expand Down Expand Up @@ -815,7 +815,7 @@ def __init__(self, config: Config, mesh: Mesh, quant: None | Quant = None, *, rn
out_features_shape=1,
use_bias=False, # Qwen3-Next shared_expert_gate does not have a bias
dtype=cfg.dtype,
kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=max_initializers.nd_dense_init(cfg.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
matmul_precision=cfg.matmul_precision,
rngs=rngs,
Expand Down Expand Up @@ -1261,7 +1261,7 @@ def __init__(
num_experts=config.num_experts,
num_experts_per_tok=config.num_experts_per_tok,
mesh=mesh,
kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=max_initializers.nd_dense_init(config.dense_init_scale, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
intermediate_dim=config.moe_mlp_dim, # same as config.mlp_dim
dtype=config.dtype,
Expand Down Expand Up @@ -1923,7 +1923,7 @@ def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
in_features=self.config.d_model_for_audio,
intermediate_dim=self.config.encoder_ffn_dim_for_audio,
activations=("gelu",), # Single GELU activation
kernel_init=max_initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_init=max_initializers.nd_dense_init(self.config.dense_init_scale, "fan_in", "truncated_normal"),
intermediate_dropout_rate=0.0, # No dropout to match AudioMLP
dtype=self.config.dtype_mm,
weight_dtype=self.config.weight_dtype,
Expand Down Expand Up @@ -2039,7 +2039,7 @@ def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
use_bias=False,
dtype=self.config.dtype_mm,
weight_dtype=self.config.weight_dtype,
kernel_init=nd_dense_init(1.0, "fan_in", "normal"),
kernel_init=nd_dense_init(self.config.dense_init_scale, "fan_in", "normal"),
matmul_precision=self.config.matmul_precision,
rngs=self.rngs,
)
Expand Down Expand Up @@ -2130,7 +2130,7 @@ def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
use_bias=True,
dtype=config.dtype_mm,
weight_dtype=config.weight_dtype,
kernel_init=nd_dense_init(1.0, "fan_in", "normal"),
kernel_init=nd_dense_init(config.dense_init_scale, "fan_in", "normal"),
matmul_precision=config.matmul_precision,
rngs=rngs,
)
Expand All @@ -2141,7 +2141,7 @@ def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
use_bias=True,
dtype=config.dtype_mm,
weight_dtype=config.weight_dtype,
kernel_init=nd_dense_init(1.0, "fan_in", "normal"),
kernel_init=nd_dense_init(config.dense_init_scale, "fan_in", "normal"),
matmul_precision=config.matmul_precision,
rngs=rngs,
)
Expand Down
Loading
Loading