Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 60 additions & 79 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from maxtext.layers import mhc
from maxtext.layers import normalizations
from maxtext.layers import pipeline
from maxtext.layers.nnx_decoders import NNXDecoderLayer, NNXSequentialPipelineStage, NNXScannedPipelineStage
from maxtext.layers import quantizations
from maxtext.layers.attentions import attention_as_linen
from maxtext.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen
Expand Down Expand Up @@ -227,47 +228,6 @@ def __call__(
return layer_output, kv_cache


class SequentialBlockDecoderLayers(nn.Module):
"""Sequential unscanned series of decoder layers."""

decoder_layer: Any
num_decoder_layers: int
config: Config
mesh: Mesh
quant: Quant
model_mode: str

@nn.compact
def __call__(
self,
inputs: jnp.ndarray,
decoder_segment_ids,
decoder_positions,
deterministic: bool,
model_mode,
slot: None | int = None,
page_state: None | page_manager.PageState = None,
) -> jnp.ndarray:
for lyr in range(self.num_decoder_layers):
inputs = self.decoder_layer(
config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
)(
inputs,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
slot=slot,
page_state=page_state,
)
if self.config.scan_layers:
inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None).
if self.config.scan_layers:
return inputs, None # pytype: disable=bad-return-type
else:
return inputs


def deepstack_process(hidden_states, bidirectional_mask, visual_embeds):
"""Process deepstack visual embeddings by adding them to hidden states at visual token positions.

Expand Down Expand Up @@ -306,10 +266,14 @@ def setup(self):
self.decoder_layer = self.get_decoder_layers()
self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim)
if self.config.using_pipeline_parallelism:
pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer)
nnx_blocks = self._get_nnx_decoder_block_classes()

def stage_factory(rngs):
return self._build_nnx_pipeline_stage(nnx_blocks, rngs)

remat_policy = self.get_remat_policy()
self.pipeline_module = pipeline.create_pipeline(
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
config=self.config, stage_factory=stage_factory, mesh=self.mesh, remat_policy=remat_policy
)

def minimal_policy(self, with_context=False, with_quantization=False):
Expand Down Expand Up @@ -491,6 +455,59 @@ def get_decoder_layers(self):
# Default case to handle any unknown decoder block types.
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")

def _get_nnx_decoder_block_classes(self):
"""Returns NNX decoder block classes for pipeline stage creation."""
cfg = self.config

def get_scannable(normal_cls, scannable_cls):
return [scannable_cls] if cfg.scan_layers else [normal_cls]

def get_deepseek():
if cfg.use_batch_split_schedule:
return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer]
return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]

layer_map = {
DecoderBlockType.DEFAULT: [NNXDecoderLayer],
DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer],
DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer],
DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer],
DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer],
DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer],
DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer],
DecoderBlockType.GEMMA4: get_scannable(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock),
DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer],
DecoderBlockType.GPT_OSS: get_scannable(gpt_oss.GptOssDecoderLayer, gpt_oss.GptOssScannableBlock),
DecoderBlockType.QWEN2: [qwen2.Qwen2DecoderLayer],
DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer],
DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer],
DecoderBlockType.QWEN3_NEXT: get_scannable(qwen3.Qwen3NextDecoderLayer, qwen3.Qwen3NextScannableBlock),
DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer],
DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer],
DecoderBlockType.DEEPSEEK: get_deepseek(),
DecoderBlockType.LLAMA4: get_scannable(llama4.Llama4DecoderLayer, llama4.Llama4ScannableBlock),
DecoderBlockType.OLMO3: get_scannable(olmo3.Olmo3DecoderLayer, olmo3.Olmo3ScannableBlock),
}

if cfg.decoder_block not in layer_map:
raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}")
return layer_map[cfg.decoder_block]

def _build_nnx_pipeline_stage(self, decoder_blocks, rngs):
"""Creates a single NNX pipeline stage module."""
cfg = self.config
base_stage_cls = decoder_blocks[1] if cfg.decoder_block == DecoderBlockType.DEEPSEEK else decoder_blocks[0]

if cfg.num_layers_per_pipeline_stage == 1:
return base_stage_cls(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs)
elif cfg.scan_layers_per_stage:
return NNXScannedPipelineStage(
base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs
)
return NNXSequentialPipelineStage(
base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs
)

def set_remat_policy(self, block_layers, policy):
"""Set remat policy"""
RemattedBlockLayers = []
Expand Down Expand Up @@ -576,42 +593,6 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args
)

def get_pipeline_stage_module(self, decoder_blocks):
"""get pipeline stage module"""

def get_layer_to_pipeline(blocks, cfg):
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
return blocks[1] # return the sparse block
else:
return blocks[0]

cfg = self.config
base_stage = get_layer_to_pipeline(decoder_blocks, cfg)
if cfg.set_remat_policy_on_layers_per_stage:
policy = self.get_remat_policy()
base_stage = self.set_remat_policy([base_stage], policy)[0]
if cfg.num_layers_per_pipeline_stage == 1:
stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode)
elif cfg.scan_layers_per_stage:
stage_module = self.scan_decoder_layers(
cfg,
base_stage,
cfg.num_layers_per_pipeline_stage,
"layers_per_stage",
self.mesh,
in_axes_tuple=(nn.broadcast,) * 4,
)
else:
stage_module = SequentialBlockDecoderLayers(
decoder_layer=base_stage,
num_decoder_layers=cfg.num_layers_per_pipeline_stage,
config=cfg,
mesh=self.mesh,
quant=self.quant,
model_mode=self.model_mode,
)
return stage_module

@nn.compact
def _apply_embedding(
self,
Expand Down
Loading