609 changes: 609 additions & 0 deletions src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py

Large diffs are not rendered by default.

581 changes: 581 additions & 0 deletions src/maxtext/checkpoint_conversion/linen_nnx_converter.py

Large diffs are not rendered by default.

@@ -35,6 +35,7 @@
"""

import argparse
import functools
import gc
import os
import sys
@@ -87,7 +88,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
mesh = Mesh(devices_array, cfg.mesh_axes)

quant = quantizations.configure_quantization(cfg)
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
if cfg.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

@@ -98,7 +102,12 @@
cfg.checkpoint_period,
)

state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager)
if cfg.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng)
state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn)
max_logging.log("start")
max_utils.print_mem_stats("After params initialized")

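For context on the refactor in this hunk: the state initializer is now pre-bound with functools.partial so that setup_training_state only receives a ready-to-call init function instead of the model and optimizer themselves. A minimal self-contained sketch of that pattern, with illustrative stand-ins rather than the real maxtext_utils implementations:

```python
# Illustrative stand-ins only; the real helpers live in maxtext_utils.
import functools


def init_initial_state(model, tx, config, is_training, rng):
  """Stand-in for maxtext_utils.init_initial_state."""
  return {"model": model, "tx": tx, "config": config, "is_training": is_training, "rng": rng}


def setup_training_state(data_iterator, config, mesh, checkpoint_manager, init_state_fn):
  """Stand-in for maxtext_utils.setup_training_state.

  It no longer takes the model/optimizer directly, only the bound initializer,
  which lets the Linen path and a future pure-NNX path share one entry point.
  """
  state = init_state_fn()
  return state, None, None, data_iterator


init_state_fn = functools.partial(init_initial_state, "model", "tx", {"pure_nnx": False}, True, 0)
state, _, _, _ = setup_training_state(None, {"pure_nnx": False}, None, None, init_state_fn)
print(state["is_training"])  # True
```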
32 changes: 27 additions & 5 deletions src/maxtext/common/checkpointing.py
@@ -20,6 +20,7 @@
from absl import flags
import datetime
from etils import epath
from flax import nnx
from flax.training import train_state
import jax
from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE
@@ -532,7 +533,7 @@ def load_state_if_possible(
load_parameters_from_path: str,
load_full_state_from_path: str,
checkpoint_storage_concurrent_gb: int,
abstract_unboxed_pre_state: train_state.TrainState,
abstract_unboxed_pre_state: train_state.TrainState | nnx.State,
enable_single_replica_ckpt_restoring: bool | None = False,
dataset_type: str | None = "tfds",
step: int = -1, # -1 means latest
@@ -600,8 +601,13 @@ def map_to_pspec(data):
)
ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True)

restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state)
checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args)
  # Convert nnx.State to a pure dict to match how NNX checkpoints are saved
restore_target = abstract_unboxed_pre_state
if isinstance(abstract_unboxed_pre_state, nnx.State):
restore_target = abstract_unboxed_pre_state.to_pure_dict()

restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target)
checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args)

match (checkpoint_manager, dataset_type, data_iterator):
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
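To make the pure-dict conversion above concrete, here is a small self-contained sketch (toy nnx module and shapes, not the MaxText transformer): an abstract nnx.State is flattened with to_pure_dict() into the plain nested dict layout that Orbax restores against, matching how pure-NNX checkpoints are written.

```python
import jax
from flax import nnx


class TinyModel(nnx.Module):
  """Toy module standing in for the real model."""

  def __init__(self, rngs: nnx.Rngs):
    self.dense = nnx.Linear(4, 2, rngs=rngs)


# Shape-only ("abstract") model, analogous to abstract_unboxed_pre_state.
abstract_model = nnx.eval_shape(lambda: TinyModel(rngs=nnx.Rngs(0)))
_, abstract_state = nnx.split(abstract_model)

# nnx.State -> plain nested dict, the structure the checkpoint was saved in.
restore_target = abstract_state.to_pure_dict()
print(jax.tree_util.tree_structure(restore_target))
```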
@@ -636,9 +642,14 @@ def map_to_pspec(data):
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)

if load_parameters_from_path != "":
if isinstance(abstract_unboxed_pre_state, nnx.State):
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
else:
params = abstract_unboxed_pre_state.params

restored_params = load_params_from_path(
load_parameters_from_path,
abstract_unboxed_pre_state.params,
params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
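As a self-contained illustration of the nnx.split call above: filtering on nnx.Param separates the learnable parameters from the rest of the module state, which is the NNX counterpart of Linen's state.params. Toy module only; the diff applies this to abstract_unboxed_pre_state.model.

```python
from flax import nnx


class TinyModel(nnx.Module):
  """Toy module: a Linear layer (Param state) plus BatchNorm (non-Param statistics)."""

  def __init__(self, rngs: nnx.Rngs):
    self.dense = nnx.Linear(4, 2, rngs=rngs)
    self.bn = nnx.BatchNorm(2, rngs=rngs)


model = TinyModel(rngs=nnx.Rngs(0))
# Split into (graphdef, Param-only state, everything else).
graphdef, params, rest = nnx.split(model, nnx.Param, ...)
print(params)  # State holding only nnx.Param leaves (kernel/bias)
print(rest)    # State holding BatchNorm statistics and other non-Param variables
```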
@@ -730,7 +741,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step
  # Determine the effective step for saving a checkpoint.
  # If 'step' is not provided, this call is for a potential final checkpoint,
  # so use the last completed step from the state.
actual_step = (int(state.step) - 1) if step is None else int(step)
if step is not None:
actual_step = int(step)
else:
if config.pure_nnx:
actual_step = int(state.optimizer.step) - 1
else:
      # Linen TrainState has a .step attribute
actual_step = int(state.step) - 1

if config.pure_nnx:
    # Convert nnx.State to a pure dict before saving.
state = state.to_pure_dict()

# Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic.
# This occurs if this function was called:
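A compact restatement of the branching above, as a hedged sketch rather than the exact MaxText helper: for pure-NNX state the step counter lives on the optimizer, for a Linen TrainState it is exposed directly, and the NNX state is flattened to a pure dict before it reaches the checkpointer.

```python
def resolve_step_and_savable(state, step, pure_nnx):
  """Sketch mirroring the step/savable resolution in maybe_save_checkpoint."""
  if step is not None:
    actual_step = int(step)
  elif pure_nnx:
    # Pure-NNX state keeps the counter on the optimizer wrapper.
    actual_step = int(state.optimizer.step) - 1
  else:
    # A Linen TrainState exposes the counter directly.
    actual_step = int(state.step) - 1

  savable = state.to_pure_dict() if pure_nnx else state
  return actual_step, savable
```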
9 changes: 7 additions & 2 deletions src/maxtext/configs/base.yml
@@ -544,9 +544,13 @@ logical_axis_rules: [
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['layers_outside_pipeline', []],
['layers_per_stage', []],
['engram_dim', ['tensor']],
['mhc', []],
['diloco', 'diloco'],
['num_activations', []],
['circular_repeats', []],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
@@ -1126,8 +1130,9 @@ position_id_per_seconds: 25
subslice_shape: ""

# NNX
enable_nnx: false
pure_nnx_decoder: false
enable_nnx: true
pure_nnx_decoder: true
pure_nnx: true

################################## Qwen3-Next Specific Configs ##################################
# Kernel size for the 1D convolution in the Gated Delta Net
28 changes: 28 additions & 0 deletions src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml
@@ -72,4 +72,32 @@ logical_axis_rules: [
['exp_with_fsdp', 'fsdp'],
['paged_kv_heads', ['tensor']],
['engram_dim', ['tensor']],
# Axes unsharded: sequence/context/tensor_transpose/autoregressive do not exist in this mesh
['activation_attn_length_no_exp', []],
['activation_length_no_exp', []],
['activation_norm_length', []],
['activation_q_length_no_exp', []],
['prefill_activation_length', []],
['prefill_activation_norm_length', []],
['activation_kv_length', []],
['decode_length', []],
['embed_tensor_transpose', []],
['q_lora_up_proj', []],
['kv_lora_up_proj', []],
['kv', []],
['qkv', []],
['kv_head_dim', []],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads_none', []],
['cache_kv', []],
['cache_sequence', []],
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['num_activations', []],
['mhc', []],
['diloco', []],
]
53 changes: 53 additions & 0 deletions src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml
@@ -34,4 +34,57 @@ logical_axis_rules: [
['q_lora', ['fsdp']],
['kv_lora', ['fsdp']],
['exp_with_fsdp', 'fsdp'],
# All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp)
['activation_heads', []],
['activation_kv_heads', []],
['activation_length', []],
['activation_attn_length', []],
['activation_attn_length_no_exp', []],
['activation_length_no_exp', []],
['activation_norm_length', []],
['activation_q_length', []],
['activation_q_length_no_exp', []],
['prefill_activation_length', []],
['prefill_activation_norm_length', []],
['activation_kv_length', []],
['activation_attn_embed', []],
['activation_embed', []],
['activation_mlp', []],
['activation_kv', []],
['activation_kv_head_dim', []],
['activation_vocab', []],
['activation_stage', []],
['activation_exp', []],
['decode_length', []],
['mlp', []],
['mlp_no_fsdp', []],
['vocab', []],
['heads', []],
['q_heads', []],
['kv_heads', []],
['embed_tensor_transpose', []],
['q_lora_up_proj', []],
['kv_lora_up_proj', []],
['norm', []],
['layers', []],
['qkv', []],
['kv', []],
['kv_head_dim', []],
['cache_batch_prefill', []],
['cache_batch', []],
['cache_heads_none', []],
['cache_heads', []],
['cache_kv', []],
['cache_sequence', []],
['exp', []],
['paged_kv_heads', []],
['num_pages', []],
['tokens_per_page', []],
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['num_activations', []],
['engram_dim', []],
['mhc', []],
['diloco', []],
]
30 changes: 3 additions & 27 deletions src/maxtext/configs/decoupled_base_test.yml
@@ -1,6 +1,7 @@
# Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
# Inherit all model defaults (Pydantic already does this) but override any cloud-coupled paths and disable
# optional cloud features.
# Inherits from base.yml so that logical_axis_rules, mesh_axes, NNX flags, and all other
# model defaults are kept in sync. Overrides only cloud-coupled paths and optional cloud features.
base_config: base.yml

# Output goes to a local relative directory so tests do not require GCS.
base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs
@@ -34,34 +35,9 @@ attention: "dot_product"
dump_hlo: false
jax_cache_dir: ""

# Neutral parallelism (single device) for local tests.
ici_data_parallelism: 1
ici_tensor_parallelism: 1
ici_pipeline_parallelism: 1
ici_expert_parallelism: 1
ici_sequence_parallelism: 1
ici_context_parallelism: 1
ici_tensor_transpose_parallelism: 1
ici_tensor_sequence_parallelism: 1
ici_autoregressive_parallelism: 1
ici_fsdp_parallelism: 1
ici_fsdp_transpose_parallelism: 1
# Allow higher unsharded parameter percentage for small device count
sharding_tolerance: 0.3

# DCN dimensions to 1 (no multi-slice expectation locally).
dcn_data_parallelism: 1
dcn_tensor_parallelism: 1
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_sequence_parallelism: 1
dcn_context_parallelism: 1
dcn_tensor_transpose_parallelism: 1
dcn_tensor_sequence_parallelism: 1
dcn_autoregressive_parallelism: 1
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1

# Config logging off unless a test overrides.
log_config: false

1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
@@ -822,6 +822,7 @@ class HardwareAndMesh(BaseModel):
optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.")
shardy: bool = Field(True, description="Whether to use shardy XLA backend.")
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")


class LayoutAndSharding(BaseModel):
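A minimal sketch of how the new flag is declared and consumed, assuming only the Pydantic field shown above and the NotImplementedError guard pattern used throughout this change:

```python
from pydantic import BaseModel, Field


class HardwareAndMesh(BaseModel):
  """Trimmed-down stand-in for the config class above."""

  pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
  pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")


cfg = HardwareAndMesh()
print(cfg.pure_nnx)  # False by default
if cfg.pure_nnx:
  # Every pure-NNX branch in this change currently raises until the implementation lands.
  raise NotImplementedError("Pure NNX support has not been implemented yet.")
```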
32 changes: 26 additions & 6 deletions src/maxtext/experimental/rl/grpo_trainer.py
@@ -546,23 +546,43 @@ def setup_train_loop(
max_logging.log("Training mesh used for the workload")
num_inference_devices = config.inference_devices_per_replica * config.inference_replicas
training_devices = jax.devices()[num_inference_devices:]
model = mt.from_config(config, devices=training_devices)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
model = mt.from_config(config, devices=training_devices)
mesh = model.mesh
max_logging.log("Inference mesh used for the workload")
inference_devices = jax.devices()[:num_inference_devices]
inference_model = mt.from_config(config_inference, devices=inference_devices)
if config_inference.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
inference_model = mt.from_config(config_inference, devices=inference_devices)
inference_mesh = inference_model.mesh
init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh)
init_rng = jax.random.PRNGKey(config.init_weights_seed)
learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model)
if config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng)
checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn)

with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION):
data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh)
state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
data_iterator, config, mesh, checkpoint_manager, init_state_fn
)

# create inference_state_mesh_shardings from inference_mesh
if config_inference.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_inference_state_fn = functools.partial(
maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng
)
inference_state_mesh_shardings = maxtext_utils.get_abstract_state(
inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False
config_inference, inference_mesh, init_inference_state_fn, is_training=False
)[2]
if not config.using_pipeline_parallelism:
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
@@ -697,7 +717,7 @@ def train_loop(config, config_inference, recorder, state=None):
data_buffer = []
data_buffer_lock = threading.Lock()

start_step = get_first_step(state) # this is the start_step for training
start_step = get_first_step(model, state) # this is the start_step for training
prof = profiler.Profiler(config, offset_step=start_step)
inference_prof = profiler.Profiler(config_inference, offset_step=start_step)
data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder)
21 changes: 16 additions & 5 deletions src/maxtext/inference/maxengine/maxengine.py
@@ -113,7 +113,10 @@ def __init__(self, config: Any, devices: Any | None = None):

# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
if config.pure_nnx:
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL)
self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None))

self.abstract_params = None
@@ -229,17 +232,25 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
rng1, rng2, rng3 = jax.random.split(rng, 3)
if params:
print("Resharding given params")
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng)
_, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state(
self.model, None, self.config, rng, self._mesh, False
self.config, self._mesh, init_state_fn, False
)
# reshard given params based on shardings from config in MaxEngine
params = jax.device_put(params, state_mesh_shardings.params)
state = maxtext_utils.init_decode_state(None, params)
state = max_utils.unbox_logicallypartioned(state)
else:
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(
self.model, self.config, rng1, self._mesh, None
)
if self.config.pure_nnx:
# NNX has a different function to init the training state.
raise NotImplementedError("Pure NNX support has not been implemented yet.")
else:
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1)
state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn)
# pylint: disable=isinstance-second-argument-not-valid-type
self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
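The reshard path above ends with jax.device_put against the shardings derived from get_abstract_state. A self-contained toy illustration of that step (illustrative mesh and shapes; MaxText takes the target shardings from the abstract state):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over whatever devices are available.
mesh = Mesh(jax.devices(), ("data",))

# Toy parameter tree and a matching tree of target shardings.
params = {"dense": {"kernel": jnp.ones((jax.device_count() * 2, 4))}}
target = {"dense": {"kernel": NamedSharding(mesh, P("data", None))}}

# device_put accepts a pytree of shardings and reshards each leaf to match.
params = jax.device_put(params, target)
print(jax.tree_util.tree_map(lambda x: x.sharding, params))
```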
4 changes: 2 additions & 2 deletions src/maxtext/layers/attentions.py
@@ -549,14 +549,14 @@ def __init__(
elif self.is_qwen3_next:
self.query_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,
)
self.key_norm = Qwen3NextRMSNorm(
num_features=self.config.head_dim,
eps=self.config.normalization_layer_epsilon,
epsilon=self.config.normalization_layer_epsilon,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
rngs=self.rngs,