diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 9b5f0cfb21..888cf4d2d1 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -35,7 +35,6 @@ """ import argparse -import functools import gc import os import sys @@ -88,10 +87,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name mesh = Mesh(devices_array, cfg.mesh_axes) quant = quantizations.configure_quantization(cfg) - if cfg.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -102,12 +98,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - if cfg.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) - state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) + state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index f8184a8e8e..edcdbd9414 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1139,9 +1139,8 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: False -pure_nnx_decoder: False -pure_nnx: False +enable_nnx: false +pure_nnx_decoder: false ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 315758015a..b6c38ee00c 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -827,7 +827,6 @@ class HardwareAndMesh(BaseModel): optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.") shardy: bool = Field(True, description="Whether to use shardy XLA backend.") pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") - pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.") class LayoutAndSharding(BaseModel): diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..100434ef74 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -546,43 +546,23 @@ def setup_train_loop( max_logging.log("Training mesh used for the workload") num_inference_devices = 
config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = mt.from_config(config, devices=training_devices) + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - if config_inference.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - inference_model = mt.from_config(config_inference, devices=inference_devices) + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng = jax.random.PRNGKey(config.init_weights_seed) - learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) - checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) + init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - data_iterator, config, mesh, checkpoint_manager, init_state_fn + model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager ) # create inference_state_mesh_shardings from inference_mesh - if config_inference.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_inference_state_fn = functools.partial( - maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng - ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( - config_inference, inference_mesh, init_inference_state_fn, is_training=False + inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage @@ -717,7 +697,7 @@ def train_loop(config, config_inference, recorder, state=None): data_buffer = [] data_buffer_lock = threading.Lock() - start_step = get_first_step(model, state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) inference_prof = profiler.Profiler(config_inference, offset_step=start_step) data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 23cd2387db..02a2f392c2 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -113,10 +113,7 @@ def __init__(self, config: Any, devices: Any | None = None): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -232,25 +229,17 @@ def load_params(self, 
*args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( - self.config, self._mesh, init_state_fn, False + self.model, None, self.config, rng, self._mesh, False ) # reshard given params based on shardings from config in MaxEngine params = jax.device_put(params, state_mesh_shardings.params) state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) - state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) + state, self.state_mesh_annotations = maxtext_utils.setup_decode_state( + self.model, self.config, rng1, self._mesh, None + ) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py deleted file mode 100644 index 9ef0e6dffd..0000000000 --- a/src/maxtext/layers/train_state_nnx.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2023–2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" The NNX Unified TrainState. """ - -from typing import Any - -from flax import nnx - - -class TrainStateNNX(nnx.Module): - """ - A unified container for NNX models and optimizers. - This replaces Linen's TrainState for checkpointing. - - Linen TrainState pytree: - {“params”: {...}, “opt_state”: {}...} - TrainStateNNX state pytree: - {“model”: {...}, “optimizer”: {“opt_state”: {...}} - """ - - def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): - self.model = model - self.optimizer = optimizer - - def apply_gradients(self, grads: Any): - """ - Mimics the Linen apply_gradients function. - Updates the optimizer state, applies updates to parameters, - and increments the step counter. - """ - if self.optimizer is None: - raise RuntimeError( - "Cannot call apply_gradients on a TrainStateNNX initialized without an optimizer. " - "This usually happens when the state was created for inference only." 
- ) - self.optimizer.update(self.model, grads) diff --git a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py index c7f6bd4740..7cc8f5b658 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py +++ b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py @@ -85,7 +85,7 @@ def train_loop(config, recorder, state=None): compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(model, state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) data_loader = DataLoader(config, mesh, data_iterator, recorder) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index ea05add5c6..b8ab2043b4 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -77,10 +77,8 @@ VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules() -def get_first_step(model, state): - if isinstance(model, nn.Module): - return int(state.step) - return int(state.optimizer.step.get_value()) +def get_first_step(state): + return int(state.step) # ----------------------------------------------------------------------------- @@ -545,7 +543,7 @@ def train_loop(config, recorder, state=None): compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(model, state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) diff --git 
a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 679768d56b..78392a388a 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -27,7 +27,6 @@ from typing import Sequence from absl import app -from flax import nnx from flax.linen import partitioning as nn_partitioning import jax from jax.experimental.serialize_executable import serialize @@ -37,7 +36,6 @@ from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations -from maxtext.layers import train_state_nnx from maxtext.models import models from maxtext.optimizers import optimizers from maxtext.trainers.diloco import diloco @@ -46,8 +44,6 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding -from maxtext.utils import maxtext_utils_nnx -from maxtext.utils import model_creation_utils # pylint: disable=too-many-positional-arguments @@ -97,10 +93,7 @@ def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state quant = quantizations.configure_quantization(config) - if config.pure_nnx: - _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh) - else: - model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) # The learning_rate_schedule is baked into the compiled object. 
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon @@ -110,39 +103,18 @@ def get_shaped_inputs(topology_mesh, config): _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) - if config.pure_nnx: - - def create_train_state_fn(): - nnx_model = _create_model_partial() - optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) - return train_state_nnx.TrainStateNNX(nnx_model, optimizer) - - init_state_fn = create_train_state_fn - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) - # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True) - - if config.pure_nnx: - # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings. - logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) - # For NNX, get_functional_train_with_signature expects the graphdef (static structure), - # not the raw model — mirroring how the training loop does nnx.split(train_state). 
- with nn_partitioning.axis_rules(config.logical_axis_rules): - graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) - model = graphdef - else: - # unsharded logical annotations - logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( + model, tx, config, example_rng, topology_mesh + ) + + # unsharded logical annotations + logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) # Shaped batch shaped_batch = maxtext_utils.get_shaped_batch(config) - if config.pure_nnx: - shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng - else: - shaped_train_args = (abstract_state, shaped_batch, shaped_rng) + shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -309,20 +281,12 @@ def main(argv: Sequence[str]) -> None: # print weights sharding info under debug sharding mode if config.debug_sharding: max_utils.print_non_trivial_mesh_axis(topology_mesh) - if config.pure_nnx: - maxtext_utils.print_shardings_params( - shaped_train_args[0], - state_mesh_shardings, - topology_mesh, - logical_annotations, - ) - else: - maxtext_utils.print_shardings_params( - shaped_train_args[0].params, - state_mesh_shardings.params, - topology_mesh, - logical_annotations.params, - ) + maxtext_utils.print_shardings_params( + shaped_train_args[0].params, + state_mesh_shardings.params, + topology_mesh, + logical_annotations.params, + ) # Compile print("Jitting and compiling train step...", flush=True) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 2fd14b87a2..7c520cc470 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ 
b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -22,7 +22,6 @@ The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16. """ -import functools import os.path from typing import Sequence @@ -43,6 +42,8 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +Transformer = models.transformer_as_linen + def _possibly_unroll_params(config, training_state, training_state_annotations, mesh): """Unroll scanned input layers when force_unroll is set.""" @@ -92,20 +93,12 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( - None, config, mesh, checkpoint_manager, init_state_fn + model, None, tx, config, rng, mesh, checkpoint_manager ) num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") @@ -116,10 +109,7 @@ def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 36e612a3f9..4be05ff7e1 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -30,7 +30,6 @@ """ -import functools import os from typing import Any, Sequence @@ -175,19 +174,12 @@ def __init__(self, config: Any, rng: PRNGKeyType): # Model and quantization config self.quant = quantizations.configure_quantization(config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - if self.config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) - - self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state( + model, None, self.config, self.rng, self._mesh, False + ) def load_and_quantize(self) -> None: """ diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 24099ef22a..03095edd73 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -14,7 +14,6 @@ """ Common LoRA utils needed to support LoRA adapters.""" -from functools import partial import json import jax @@ -167,12 +166,7 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, rng, mesh, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index bdec9c1f10..675f920357 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -196,11 +196,8 @@ def get_train_input_output_trees(func, input_args, input_kwargs): serialized_compiled = load_serialized_compiled(config.compiled_trainstep_file) shaped_batch = get_shaped_batch(config) - if config.pure_nnx: - shaped_input_args = (state, shaped_batch) - else: - example_rng = jax.random.PRNGKey(0) - shaped_input_args = (state, shaped_batch, example_rng) + example_rng = jax.random.PRNGKey(0) + shaped_input_args = (state, shaped_batch, example_rng) shaped_input_kwargs = {} in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs) p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree, execution_devices=execution_devices) @@ -1063,13 +1060,14 @@ def get_abstract_param(model, config): return abstract_vars -def setup_decode_state(config, mesh, checkpoint_manager, init_state_fn): +def setup_decode_state(model, config, rng, mesh, checkpoint_manager): """Setup decode state by loading params from a checkpoint. 
Args: + model: the flax model to initialize config: config object + rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: Checkpoint manager - init_state_fn: function to initialize the model state Returns: state: state with decode params loaded from the checkpoint @@ -1079,12 +1077,12 @@ def setup_decode_state(config, mesh, checkpoint_manager, init_state_fn): # generate random params max_logging.log("No decode checkpoint specified - generating random weights.") state, state_mesh_annotations, _, _ = setup_initial_state( - None, config, mesh, checkpoint_manager, init_state_fn, False + model, None, None, config, rng, mesh, checkpoint_manager, False ) else: # Load params from checkpoint max_logging.log(f"Loading decode params from {config.load_parameters_path}") - unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(config, mesh, init_state_fn, False) + unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(model, None, config, rng, mesh, False) with nn_partitioning.axis_rules(config.logical_axis_rules): params = checkpointing.load_params_from_path( config.load_parameters_path, @@ -1099,35 +1097,40 @@ def setup_decode_state(config, mesh, checkpoint_manager, init_state_fn): return state, state_mesh_annotations -def setup_training_state(data_iterator, config, mesh, checkpoint_manager, init_state_fn): +def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): is_training = True return setup_initial_state( + model, data_iterator, + tx, config, + rng, mesh, checkpoint_manager, - init_state_fn, is_training, ) def setup_initial_state( + model, data_iterator, + tx, config, + rng, mesh, checkpoint_manager, - init_state_fn, is_training=True, ): """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. 
Args: - data_iterator: data iterator + model: the flax model to initialize + tx: the optax.GradientTransformation config: config object + rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: an Orbax checkpointing.CheckpointManager object - init_state_fn: function to initialize the training state is_training: True to initialize training state, False for decode state Returns: @@ -1136,7 +1139,7 @@ def setup_initial_state( """ unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( - config, mesh, init_state_fn, is_training + model, tx, config, rng, mesh, is_training ) # Initialization @@ -1171,14 +1174,14 @@ def setup_initial_state( # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] else: - init_state_partial = init_state_fn + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) init_state_partial.__name__ = "initialize_state" # pylint: disable=not-callable state = jax.jit( init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings, - )() + )(rng) if raw_params: # If we loaded a partial state, we need to merge it. 
state = state.replace(params=raw_params) @@ -1187,8 +1190,8 @@ def setup_initial_state( return state, state_mesh_annotations, state_mesh_shardings, data_iterator -def get_logical_annotations(config, mesh, init_state_fn): - init_state_partial = init_state_fn +def get_logical_annotations(model, tx, config, rng, mesh, is_training=True): + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) @@ -1196,9 +1199,9 @@ def get_logical_annotations(config, mesh, init_state_fn): return logical_annotations -def get_abstract_state(config, mesh, init_state_fn, is_training=True): +def get_abstract_state(model, tx, config, rng, mesh, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" - init_state_partial = init_state_fn + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py deleted file mode 100644 index 7378928ef2..0000000000 --- a/src/maxtext/utils/maxtext_utils_nnx.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2023–2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Utils for MaxText NNX. 
""" - -from functools import partial -from typing import Callable - -from flax import nnx -import jax -from jax.sharding import Mesh, NamedSharding - -from maxtext.utils import max_logging -from maxtext.configs import pyconfig - - -def create_nnx_rngs( - config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None -) -> nnx.Rngs: - """ - Create NNX Rngs - - Args: - config: the configuration - is_training: if the Rngs are for training - rng_key: the Rng key - - Returns: - The NNX Rngs - """ - if rng_key is None: - rng_key = jax.random.PRNGKey(config.init_weights_seed) - - if is_training: - return nnx.Rngs( - params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2) - ) - return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference - - -def get_named_sharding_nnx(abstract_state: nnx.State) -> nnx.State: - """Get named sharding from NNX abstract state. - - Args: - abstract_state: NNX model abstract state created from nnx.get_abstract_model. - - Returns: - named sharding structure - """ - # Don't use nnx.get_named_sharding() because it constructs new shardings. Instead, we - # get the existing sharding from the abstract_state. - # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) - return jax.tree.map( - lambda x: x.sharding, - abstract_state, - is_leaf=lambda x: isinstance(x, jax.ShapeDtypeStruct), - ) - - -def get_partition_spec_nnx(named_sharding: nnx.State) -> nnx.State: - """Get mesh partition spec from named sharding. - - Args: - named_sharding: NNX model named sharding. - - Returns: - mesh partition spec - """ - # The leaf is of type NamedSharding. - return jax.tree.map( - lambda x: x.spec, - named_sharding, - is_leaf=lambda x: isinstance(x, NamedSharding), - ) - - -def set_named_sharding_nnx(abstract_state: nnx.State, named_sharding: nnx.State) -> nnx.State: - """Set named sharding to NNX abstract state. 
- - Args: - abstract_state: NNX model abstract state created from nnx.get_abstract_model(). - named_sharding: named sharding. It must have the same tree structure with abstract_state. - - Returns: - updated abstract_state - """ - return jax.tree.map(lambda x, y: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=y), abstract_state, named_sharding) - - -def move_memory_to_host(path: tuple[str, ...], x: NamedSharding) -> NamedSharding: - """ - Change the memory_kind of the NamedSharding to "pinned_host". This function can be - called by jax.tree_util.tree_map_with_path on a NNX state structure. - - Args: - path: the tree path tuple - x: the NamedSharding corresponding to the path - - Returns: - the NamedSharding with memory_kind set to "pinned_host" - """ - max_logging.log(f"max_utils.py: Moving {path} to host") - # Create the new sharding with the target memory kind - return x.with_memory_kind(kind="pinned_host") - - -def move_memory_to_device(path: tuple[str, ...], x: NamedSharding) -> NamedSharding: - """ - Change the memory_kind of the NamedSharding to "device". This function can be - called by jax.tree_util.tree_map_with_path on a NNX state structure. - - Args: - path: the tree path tuple - x: the NamedSharding corresponding to the path - - Returns: - the NamedSharding with memory_kind set to "device" - """ - max_logging.log(f"max_utils.py: Moving {path} to device") - # Create the new sharding with the target memory kind - return x.with_memory_kind(kind="device") - - -def create_nnx_sharded_model( - abstract_model: nnx.Module, - init_fn: Callable, - mesh: Mesh | None = None, - named_sharding: nnx.State | None = None, -) -> nnx.Module: - """ - Create the model with the given sharding. 
- - Args: - abstract_model: the abstract model - init_fn: the model init function - mesh: the device mesh - named_sharding: the given sharding - - Returns: - The initialized sharded model - """ - graphdef, abstract_state = nnx.split(abstract_model) - if named_sharding is None: - # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) - # we get the sharding directly from it. - named_sharding = get_named_sharding_nnx(abstract_state) - - if mesh is None: - mesh = abstract_model.mesh - - # JIT a function that creates the model state with proper sharding from the start. - # By providing out_shardings, we instruct JAX to produce sharded output directly, - # avoiding a large intermediate allocation on a single device. - @partial(jax.jit, out_shardings=named_sharding) - def create_sharded_state(): - model = init_fn() - return jax.lax.with_sharding_constraint(nnx.state(model), named_sharding) - - # Create the model with sharded parameters. - with jax.set_mesh(mesh): - sharded_state = create_sharded_state() - return nnx.merge(graphdef, sharded_state) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 49fb9d3490..f492744b24 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -226,38 +226,6 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Returns (_create_model_partial, abstract_model) for AOT compilation. - - Unlike create_nnx_model, this does not shard parameters or load checkpoints. - It only builds the abstract shape/dtype structure needed by get_abstract_state - and optimizer construction (e.g. Muon). 
- - Args: - config: the configuration - mesh: the device mesh - model_mode: train or inference - rng_key: optional RNG key - - Returns: - (_create_model_partial, abstract_model) where _create_model_partial() creates - a concrete model instance and abstract_model is the eval_shape result. - """ - - def _create_model(rng_key=None): - if rng_key is None: - rng_key = jax.random.PRNGKey(config.init_weights_seed) - rngs = nnx.Rngs(params=rng_key, dropout=1) - return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode) - - _create_model_partial = partial(_create_model, rng_key=rng_key) - - with nn.logical_axis_rules(config.logical_axis_rules): - abstract_model = nnx.eval_shape(_create_model_partial) - - return _create_model_partial, abstract_model - - def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index ba6b148b04..1aaf800030 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -19,7 +19,6 @@ # See github.com/google/maxtext/issues/20 for more import datetime -from functools import partial import os from typing import Sequence @@ -52,21 +51,11 @@ def checkpoint_loop(config, state=None): Returns: """ - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = from_config(config) + model = from_config(config) mesh = model.mesh - init_rng = jax.random.PRNGKey(config.init_weights_seed) - _, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) - checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) - - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) + init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(config, model, mesh) + + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) # A barrier to sync all hosts before starting to restore checkpoint jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() @@ -92,10 +81,10 @@ def checkpoint_loop(config, state=None): "STANDALONE CHECKPOINTER : Checkpoint restored in :" f" {checkpoint_load_end - checkpoint_load_start}" ) else: # Checkpoint was unavailable, state needs to be initialized - state, _, _, _ = maxtext_utils.setup_training_state(None, config, mesh, checkpoint_manager, init_state_fn) + state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, config, init_rng, mesh, checkpoint_manager) state = add_entropy_to_checkpoint(state) - start_step = get_first_step(model, state) # this is the start_step for training + start_step = get_first_step(state) # this is the start_step for training for step in np.arange(start_step, config.steps): if checkpoint_manager is not None: start_time = datetime.datetime.now() diff --git a/src/maxtext/utils/standalone_dataloader.py b/src/maxtext/utils/standalone_dataloader.py index ed77e61b35..e8a942fa1e 100644 --- a/src/maxtext/utils/standalone_dataloader.py +++ b/src/maxtext/utils/standalone_dataloader.py @@ -38,13 +38,13 @@ def data_load_loop(config, state=None): """Main data loader loop. Loads batches of data for each training step. 
""" - _, _, _, model, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) + _, _, _, _, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) data_loader = DataLoader(config, mesh, data_iterator, None) example_batch = None start = datetime.datetime.now() - start_step = get_first_step(model, state) + start_step = get_first_step(state) example_batch = data_loader.load_next_batch() jax.block_until_ready(example_batch) first_end = datetime.datetime.now() diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 9e0a00c8e6..2ed71a6e3f 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -16,8 +16,6 @@ """Utils that are only interesting for training in MaxText.""" import os -from functools import partial - import jax import functools from flax.linen import partitioning as nn_partitioning @@ -35,17 +33,12 @@ from maxtext.trainers.diloco import diloco -def create_training_optimizer(config, model): - """Creates the optimizer and learning rate schedule.""" +def create_training_tools(config, model, mesh): + """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" + init_rng = jax.random.PRNGKey(config.init_weights_seed) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon tx = optimizers.get_optimizer(config, learning_rate_schedule, model) - return learning_rate_schedule, tx - - -def create_checkpoint_manager(config, mesh, init_state_fn): - """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" - # pass in model for muon logger = checkpointing.setup_checkpoint_logger(config) if config.enable_multi_tier_checkpointing: checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( @@ -54,7 +47,7 @@ def create_checkpoint_manager(config, mesh, init_state_fn): mesh, ) elif config.enable_emergency_checkpoint: - abstract_state, _, 
_ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) + abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( config.local_checkpoint_directory, config.checkpoint_dir, @@ -94,10 +87,10 @@ def create_checkpoint_manager(config, mesh, init_state_fn): config.checkpoint_todelete_full_path, ) - return checkpoint_manager + return init_rng, checkpoint_manager, learning_rate_schedule, tx -def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=None): +def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings): """Returns a JIT-compiled train step function, which is loaded from a file if specified in the config.""" if config.enable_diloco: functional_train = train_step @@ -119,9 +112,7 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit if config.compiled_trainstep_file != "": max_logging.log("Loading the compiled function...") - # For NNX, model is the GraphDef (no .mesh); use the mesh passed explicitly instead. - execution_mesh = mesh if mesh is not None else model.mesh - execution_devices = execution_mesh.devices.flatten().tolist() + execution_devices = model.mesh.devices.flatten().tolist() # Need to pass train signature and state to determine i/o shapes of train_state for now. 
p_train_step = maxtext_utils.load_compiled(config, functional_train, state, execution_devices) max_logging.log("Loaded compiled function!") @@ -176,9 +167,7 @@ def jit_train_and_eval_step( train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh) data_sharding = sharding.get_input_data_sharding(config, mesh) - p_train_step = jit_train_step( - config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh - ) + p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings) p_eval_step = None if eval_data_iterator: p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) @@ -210,21 +199,9 @@ def setup_train_loop(config, recorder, devices=None): from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): - is_training = True - init_rng = jax.random.PRNGKey(config.init_weights_seed) - if config.pure_nnx: - # Create abstract NNX model. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = model_creation_utils.from_config(config, devices) + model = model_creation_utils.from_config(config, devices) mesh = model.mesh - learning_rate_schedule, tx = create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) - checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) + init_rng, checkpoint_manager, learning_rate_schedule, tx = create_training_tools(config, model, mesh) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator, eval_data_iterator = create_data_iterator(config, mesh) @@ -250,7 +227,7 @@ def setup_train_loop(config, recorder, devices=None): ) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - data_iterator, config, mesh, checkpoint_manager, init_state_fn + model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager ) if config.enable_diloco: @@ -273,14 +250,14 @@ def setup_train_loop(config, recorder, devices=None): # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True) max_utils.print_non_trivial_mesh_axis(model.mesh) maxtext_utils.print_shardings_params( state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params ) if config.use_dpo: - abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) + abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) diff --git a/tests/assets/logits_generation/generate_grpo_golden_logits.py b/tests/assets/logits_generation/generate_grpo_golden_logits.py index cae8b9e4d3..e4e9f4fe8a 100644 --- a/tests/assets/logits_generation/generate_grpo_golden_logits.py +++ 
b/tests/assets/logits_generation/generate_grpo_golden_logits.py
@@ -38,7 +38,7 @@
 from maxtext.inference.maxengine import maxengine
 from maxtext.models import models
 from maxtext.utils import maxtext_utils
-from tests.post_training.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs
+from tests.post_training.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs
 import numpy as np
 import torch
 import transformers
@@ -73,27 +73,17 @@ def setUp(self):
     devices_array = maxtext_utils.create_device_mesh(self.cfg)
     mesh = Mesh(devices_array, self.cfg.mesh_axes)
     # With checkpoint
-    if self.cfg.pure_nnx:
-      # NNX has a different function to init the training state.
-      raise NotImplementedError("Pure NNX support has not been implemented yet.")
-    else:
-      self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
-      init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng)
-      self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn)
+    self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
+    self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None)
     self.state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, self.cfg.logical_axis_rules)
     self.data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
 
     # Without checkpoint
-    if self.cfg_no_ckpt_loading.pure_nnx:
-      # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - self.model_no_ckpt_loading = models.transformer_as_linen( - config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN - ) - init_state_fn = functools.partial( - maxtext_utils.init_initial_state, self.model_no_ckpt_loading, None, self.cfg_no_ckpt_loading, False, self.rng - ) - self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state(self.cfg_no_ckpt_loading, mesh, None, init_state_fn) + self.model_no_ckpt_loading = models.transformer_as_linen( + config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN + ) + self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state( + self.model_no_ckpt_loading, self.cfg_no_ckpt_loading, self.rng, mesh, None + ) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", diff --git a/tests/post_training/integration/grpo_correctness.py b/tests/post_training/integration/grpo_correctness.py index adefc03a7e..44a3e28df7 100644 --- a/tests/post_training/integration/grpo_correctness.py +++ b/tests/post_training/integration/grpo_correctness.py @@ -13,7 +13,6 @@ # limitations under the License. """GRPO correctness tests""" -import functools import os import unittest @@ -61,13 +60,8 @@ def setUp(self): self.rng = jax.random.PRNGKey(42) devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) - if self.cfg.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng) - self.state, _ = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn) + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", add_bos_token=False, @@ -127,7 +121,7 @@ def _prepare_maxtext_inputs(self): ) def _prepare_trl_inputs(self): - """Prepare inputs for TRL model.""" + """Prepare TRL inputs.""" tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt") input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1) attention_mask = torch.cat( diff --git a/tests/post_training/integration/grpo_trainer_correctness_test.py b/tests/post_training/integration/grpo_trainer_correctness_test.py index b880a0e678..9a2cfd4078 100644 --- a/tests/post_training/integration/grpo_trainer_correctness_test.py +++ b/tests/post_training/integration/grpo_trainer_correctness_test.py @@ -25,7 +25,6 @@ pytest tests/post_training/integration/grpo_trainer_correctness_test.py """ -import functools import os import subprocess import sys @@ -73,13 +72,8 @@ def setup_maxtext_model(config, mesh): init_rng = jax.random.PRNGKey(config.init_weights_seed) quant = quantizations.configure_quantization(config) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng) - state, state_mesh_annotations = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) + maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + state, state_mesh_annotations = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None) state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, config.logical_axis_rules) data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) reference_params = jax.tree.map(jnp.copy, state.params["params"]) diff --git a/tests/post_training/integration/sft_trainer_correctness_test.py b/tests/post_training/integration/sft_trainer_correctness_test.py index 89ac19d0f3..beeb2036d9 100644 --- a/tests/post_training/integration/sft_trainer_correctness_test.py +++ b/tests/post_training/integration/sft_trainer_correctness_test.py @@ -24,7 +24,6 @@ pytest tests/post_training/integration/sft_trainer_correctness_test.py """ -import functools import os.path import subprocess import sys @@ -118,13 +117,8 @@ def setup_maxtext_model(config): quant = quantizations.configure_quantization(config) devices_array = maxtext_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.")
-    else:
-      maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
-      init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng)
-      state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn)
+    maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
+    state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None)
 
     return maxtext_model, state, init_rng
 
diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py
index 4850e972b3..a65a905c7f 100644
--- a/tests/unit/maxtext_utils_test.py
+++ b/tests/unit/maxtext_utils_test.py
@@ -14,9 +14,8 @@
 """Tests for the common MaxText utilities"""
 
-import functools
-from typing import Any
 from collections.abc import Callable
+from typing import Any
 import unittest
 from unittest.mock import MagicMock, Mock
 
@@ -29,7 +28,7 @@
 import jax.numpy as jnp
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from maxtext.configs import pyconfig
-from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN
+from maxtext.common.common_types import MODEL_MODE_TRAIN
 from maxtext.inference import inference_utils
 from maxtext.layers import quantizations
 from maxtext.models import models
@@ -352,31 +351,18 @@ def setUp(self):
     devices_array = maxtext_utils.create_device_mesh(self.config)
     self.mesh = Mesh(devices_array, self.config.mesh_axes)
     quant = quantizations.configure_quantization(self.config)
-    if self.config.pure_nnx:
-      raise NotImplementedError("Pure NNX support has not been implemented yet.")
-    else:
-      self.model = models.transformer_as_linen(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
+    self.model = models.transformer_as_linen(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
 
   def 
test_setup_decode_state(self): rng = random.PRNGKey(0) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) - state, _ = maxtext_utils.setup_decode_state(self.config, self.mesh, None, init_state_fn) + state, _ = maxtext_utils.setup_decode_state(self.model, self.config, rng, self.mesh, None) self.assertEqual(state.tx, None) self.assertEqual(state.opt_state, {}) def test_setup_initial_state(self): rng = random.PRNGKey(0) tx = optax.adam(learning_rate=0.001) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, tx, self.config, True, rng) - state, _, _, _ = maxtext_utils.setup_initial_state(None, self.config, self.mesh, None, init_state_fn) + state, _, _, _ = maxtext_utils.setup_initial_state(self.model, None, tx, self.config, rng, self.mesh, None) self.assertEqual(state.tx, tx) self.assertNotEqual(state.opt_state, {}) @@ -945,8 +931,7 @@ def setUp(self): def test_get_abstract_state(self): """Tests that get_abstract_state returns abstract arrays.""" # get_abstract_state returns a tuple, the first element is the abstract state. 
- init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, self.tx, self.config, True, self.rng) - abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self.mesh, init_state_fn) + abstract_state, _, _ = maxtext_utils.get_abstract_state(self.model, self.tx, self.config, self.rng, self.mesh, None) # Check that params are abstract param_leaves = jax.tree_util.tree_leaves(abstract_state.params) @@ -957,435 +942,5 @@ def test_get_abstract_state(self): self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in opt_state_leaves)) -class TestGetFunctionalTrainWithSignature(unittest.TestCase): - """Tests for get_functional_train_with_signature.""" - - def _make_mock_step(self): - def train_step(_model, _config, _state_shardings, _params_shardings, state, _batch, _rng=None): - return state, {} - - return train_step - - def test_returns_five_tuple(self): - step = self._make_mock_step() - result = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" - ) - self.assertEqual(len(result), 5) - - def test_functional_train_has_correct_name(self): - step = self._make_mock_step() - fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" - ) - self.assertEqual(fn.__name__, "train_step") - - def test_in_shardings_structure(self): - step = self._make_mock_step() - _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" - ) - # (state, batch, rng) - self.assertEqual(len(in_shardings), 3) - self.assertIsNone(in_shardings[2]) # rng sharding is None - - def test_donate_argnums_is_zero(self): - step = self._make_mock_step() - _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" - ) - self.assertEqual(donate_argnums, 0) - - def 
test_functional_train_is_partial(self): - """functional_train should partially apply model and config.""" - received = {} - - def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None): - received["model"] = model - received["config"] = config - return state, {} - - fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", "my_config") - fn("state", "batch") - self.assertEqual(received["model"], "my_model") - self.assertEqual(received["config"], "my_config") - - -class TestGetFunctionalEvalWithSignature(unittest.TestCase): - """Tests for get_functional_eval_with_signature.""" - - def _make_mock_eval_step(self): - def eval_step(_model, _config, _state, _batch, _rng=None): - return {} - - return eval_step - - def test_returns_five_tuple(self): - step = self._make_mock_eval_step() - result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") - self.assertEqual(len(result), 5) - - def test_functional_eval_has_correct_name(self): - step = self._make_mock_eval_step() - fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") - self.assertEqual(fn.__name__, "eval_step") - - def test_out_shardings_is_none(self): - step = self._make_mock_eval_step() - _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") - self.assertIsNone(out_shardings) - - def test_donate_argnums_is_empty(self): - step = self._make_mock_eval_step() - _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") - self.assertEqual(donate_argnums, ()) - - -class TestGetShapedBatch(unittest.TestCase): - """Tests for get_shaped_batch.""" - - def _make_cfg(self, *, enable_diloco=False, use_multimodal=False, use_audio=False): - cfg = MagicMock() - cfg.enable_diloco = enable_diloco - cfg.global_batch_size_to_load = 4 - cfg.max_target_length = 
16 - cfg.use_multimodal = use_multimodal - cfg.use_audio = use_audio - if enable_diloco: - cfg.num_diloco_replicas = 2 - return cfg - - def test_standard_keys_present(self): - batch = maxtext_utils.get_shaped_batch(self._make_cfg()) - for key in ( - "inputs", - "inputs_position", - "inputs_segmentation", - "targets", - "targets_position", - "targets_segmentation", - ): - self.assertIn(key, batch) - - def test_standard_shape(self): - cfg = self._make_cfg() - batch = maxtext_utils.get_shaped_batch(cfg) - expected_shape = (cfg.global_batch_size_to_load, cfg.max_target_length) - self.assertEqual(batch["inputs"].shape, expected_shape) - - def test_diloco_shape(self): - cfg = self._make_cfg(enable_diloco=True) - batch = maxtext_utils.get_shaped_batch(cfg) - expected_shape = ( - cfg.num_diloco_replicas, - cfg.global_batch_size_to_load // cfg.num_diloco_replicas, - cfg.max_target_length, - ) - self.assertEqual(batch["inputs"].shape, expected_shape) - - def test_no_image_key_without_multimodal(self): - batch = maxtext_utils.get_shaped_batch(self._make_cfg(use_multimodal=False)) - self.assertNotIn("images", batch) - - def test_no_audio_key_without_audio(self): - batch = maxtext_utils.get_shaped_batch(self._make_cfg(use_audio=False)) - self.assertNotIn("audios", batch) - - def test_all_values_are_shape_dtype_struct(self): - batch = maxtext_utils.get_shaped_batch(self._make_cfg()) - for v in batch.values(): - self.assertIsInstance(v, jax.ShapeDtypeStruct) - - -class TestShouldPreventCseInRemat(unittest.TestCase): - """Tests for should_prevent_cse_in_remat.""" - - def _make_cfg(self, scan_layers=False, gradient_accumulation_steps=1, hardware="tpu"): - cfg = MagicMock() - cfg.scan_layers = scan_layers - cfg.gradient_accumulation_steps = gradient_accumulation_steps - cfg.hardware = hardware - return cfg - - def test_scan_layers_returns_false(self): - self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(self._make_cfg(scan_layers=True))) - - def 
test_gpu_with_grad_accum_returns_false(self): - cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="gpu") - self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(cfg)) - - def test_gpu_multiprocess_with_grad_accum_returns_false(self): - cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="gpu_multiprocess") - self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(cfg)) - - def test_tpu_with_grad_accum_returns_true(self): - cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="tpu") - self.assertTrue(maxtext_utils.should_prevent_cse_in_remat(cfg)) - - def test_default_case_returns_true(self): - self.assertTrue(maxtext_utils.should_prevent_cse_in_remat(self._make_cfg())) - - -class TestCalculateTokensTrainingPerDevice(unittest.TestCase): - """Tests for calculate_tokens_training_per_device.""" - - def test_basic_calculation(self): - cfg = MagicMock() - cfg.max_target_length = 128 - cfg.per_device_batch_size = 2 - cfg.gradient_accumulation_steps = 4 - result = maxtext_utils.calculate_tokens_training_per_device(cfg) - self.assertEqual(result, 128 * 2 * 4) - - -class TestCalculateIndexerMaskRatio(unittest.TestCase): - """Tests for calculate_indexer_mask_ratio.""" - - def test_half_topk(self): - # K=T/2: ratio=0.5, mask = 0.5 - 0.5*0.25 = 0.375 - result = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=4, max_target_length=8) - self.assertAlmostEqual(result, 0.375, places=6) - - def test_full_topk_equals_dense(self): - # K=T: ratio=1, mask = 1 - 0.5 = 0.5 (same as causal) - result = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=8, max_target_length=8) - self.assertAlmostEqual(result, 0.5, places=6) - - def test_small_topk(self): - # K=1, T=100: ratio=0.01, mask ≈ 0.01 - 0.5*0.0001 ≈ 0.00995 - result = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=1, max_target_length=100) - expected = 0.01 - 0.5 * (0.01**2) - self.assertAlmostEqual(result, expected, 
places=8) - - -class TestCalculateFfnMatmulTflops(unittest.TestCase): - """Tests for calculate_ffn_mamtul_tflops_per_device.""" - - def _make_cfg(self, num_activations=2): - cfg = MagicMock() - cfg.per_device_batch_size = 1 - cfg.max_target_length = 64 - cfg.emb_dim = 512 - cfg.mlp_activations = ["silu"] * num_activations - return cfg - - def test_total_flops_positive(self): - result = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(self._make_cfg(), mlp_dim=2048) - self.assertGreater(result, 0) - - def test_scales_with_mlp_dim(self): - cfg = self._make_cfg() - small = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=1024) - large = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=4096) - self.assertGreater(large, small) - - def test_single_activation(self): - """With one activation, ffn1 uses 1x mlp_dim.""" - cfg = self._make_cfg(num_activations=1) - result = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=2048) - expected_ffn1 = 2 * 1 * 64 * 2048 * 512 * 1 - expected_ffn2 = 2 * 1 * 64 * 2048 * 512 - self.assertEqual(result, expected_ffn1 + expected_ffn2) - - -class TestGetDenseMoeLayers(unittest.TestCase): - """Tests for get_dense_moe_layers.""" - - def _make_cfg(self, decoder_block, num_decoder_layers=32, first_num_dense_layers=3, interleave_moe_layer_step=4): - cfg = MagicMock() - cfg.decoder_block = decoder_block - cfg.num_decoder_layers = num_decoder_layers - cfg.first_num_dense_layers = first_num_dense_layers - cfg.interleave_moe_layer_step = interleave_moe_layer_step - return cfg - - def test_deepseek_block(self): - cfg = self._make_cfg(DecoderBlockType.DEEPSEEK, num_decoder_layers=32, first_num_dense_layers=3) - dense, moe = maxtext_utils.get_dense_moe_layers(cfg) - self.assertEqual(dense, 3) - self.assertEqual(moe, 29) - - def test_llama4_block(self): - cfg = self._make_cfg(DecoderBlockType.LLAMA4, num_decoder_layers=16, interleave_moe_layer_step=4) - dense, moe = 
maxtext_utils.get_dense_moe_layers(cfg) - self.assertEqual(moe, 4) # 16 // 4 - self.assertEqual(dense, 12) # 16 - 4 - - def test_qwen3_next_block(self): - cfg = self._make_cfg(DecoderBlockType.QWEN3_NEXT, num_decoder_layers=8) - dense, moe = maxtext_utils.get_dense_moe_layers(cfg) - self.assertEqual(dense, 0) - self.assertEqual(moe, 8) - - def test_unsupported_block_raises(self): - cfg = self._make_cfg(DecoderBlockType.DEFAULT) - with self.assertRaises(ValueError): - maxtext_utils.get_dense_moe_layers(cfg) - - -class TestCalculatePrefillTflops(unittest.TestCase): - """Tests for calculate_prefill_tflops_per_device.""" - - def _make_cfg(self, num_query_heads=8, num_decoder_layers=2, head_dim=64): - cfg = MagicMock() - cfg.num_query_heads = num_query_heads - cfg.num_decoder_layers = num_decoder_layers - cfg.head_dim = head_dim - return cfg - - def test_returns_three_positive_values(self): - cfg = self._make_cfg() - total, lw, attn = maxtext_utils.calculate_prefill_tflops_per_device( - num_model_parameters=1_000_000, prefill_length=128, config=cfg, log=False - ) - self.assertGreater(total, 0) - self.assertGreater(lw, 0) - self.assertGreater(attn, 0) - - def test_total_is_sum_of_parts(self): - cfg = self._make_cfg() - total, lw, attn = maxtext_utils.calculate_prefill_tflops_per_device( - num_model_parameters=500_000, prefill_length=64, config=cfg, log=False - ) - self.assertAlmostEqual(total, lw + attn, places=10) - - def test_scales_with_prefill_length(self): - cfg = self._make_cfg() - _, _, attn_short = maxtext_utils.calculate_prefill_tflops_per_device( - num_model_parameters=1_000_000, prefill_length=64, config=cfg, log=False - ) - _, _, attn_long = maxtext_utils.calculate_prefill_tflops_per_device( - num_model_parameters=1_000_000, prefill_length=128, config=cfg, log=False - ) - # Attention scales quadratically with prefill length - self.assertGreater(attn_long, attn_short * 3) - - -class TestSetupTrainingState(unittest.TestCase): - """Tests for setup_training_state 
(thin wrapper over setup_initial_state).""" - - def setUp(self): - extra_args = get_decoupled_parallelism_overrides() - self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args) - devices_array = maxtext_utils.create_device_mesh(self.config) - self.mesh = Mesh(devices_array, self.config.mesh_axes) - quant = quantizations.configure_quantization(self.config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX path not covered by this test.") - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - - def test_setup_training_state_returns_train_state(self): - rng = jax.random.PRNGKey(0) - tx = optax.adam(learning_rate=0.001) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, tx, self.config, True, rng) - state, _, _, _ = maxtext_utils.setup_training_state(None, self.config, self.mesh, None, init_state_fn) - self.assertEqual(state.tx, tx) - self.assertNotEqual(state.opt_state, {}) - - -class TestGetLogicalAnnotations(unittest.TestCase): - """Tests for get_logical_annotations.""" - - def setUp(self): - extra_args = get_decoupled_parallelism_overrides() - self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args) - devices_array = maxtext_utils.create_device_mesh(self.config) - self.mesh = Mesh(devices_array, self.config.mesh_axes) - quant = quantizations.configure_quantization(self.config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX path not covered by this test.") - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - self.rng = jax.random.PRNGKey(0) - self.tx = optax.adam(learning_rate=0.001) - - def test_returns_partition_spec_tree(self): - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, self.tx, self.config, True, self.rng) - annotations = 
maxtext_utils.get_logical_annotations(self.config, self.mesh, init_state_fn) - # Result should be a pytree with PartitionSpec leaves - leaves = jax.tree_util.tree_leaves(annotations) - self.assertGreater(len(leaves), 0) - for leaf in leaves: - self.assertIsInstance(leaf, PartitionSpec) - - -class TestSaveQuantizedCheckpoint(unittest.TestCase): - """Tests for save_quantized_checkpoint_if_configured.""" - - def test_raises_when_no_quantization(self): - cfg = MagicMock() - cfg.quantization = "" - with self.assertRaises(AssertionError): - maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={}) - - @unittest.mock.patch("maxtext.utils.maxtext_utils.checkpointing") - def test_skips_save_when_path_empty(self, mock_ckpt): - cfg = MagicMock() - cfg.quantization = "int8" - cfg.save_quantized_params_path = "" - maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={}) - mock_ckpt.save_params_to_path.assert_not_called() - - @unittest.mock.patch("maxtext.utils.maxtext_utils.checkpointing") - def test_calls_save_when_path_set(self, mock_ckpt): - cfg = MagicMock() - cfg.quantization = "int8" - cfg.save_quantized_params_path = "/tmp/quantized" - cfg.checkpoint_storage_use_ocdbt = True - cfg.checkpoint_storage_use_zarr3 = True - maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={"w": jnp.ones((2,))}) - mock_ckpt.save_params_to_path.assert_called_once() - - -class TestAddConfigToSummaryWriter(unittest.TestCase): - """Tests for add_config_to_summary_writer.""" - - def test_calls_add_text_for_each_key(self): - cfg = MagicMock() - cfg.get_keys.return_value = {"learning_rate": 0.001, "steps": 100} - mock_writer = MagicMock() - - with unittest.mock.patch("maxtext.utils.max_utils.add_text_to_summary_writer") as mock_add: - maxtext_utils.add_config_to_summary_writer(cfg, mock_writer) - # Should have been called once per config key (process_index==0 in tests) - if jax.process_index() == 0: - self.assertEqual(mock_add.call_count, 2) - - -class 
TestMaybeDumpJaxpr(unittest.TestCase): - """Tests for maybe_dump_jaxpr.""" - - def test_early_return_when_disabled(self): - cfg = MagicMock() - cfg.dump_jaxpr = False - # Should return immediately without calling any JAX tracing (no exception raised) - maxtext_utils.maybe_dump_jaxpr(cfg, p_train_step=None, train_step_inputs=None) - - -class TestPrintShardingsParams(unittest.TestCase): - """Tests for print_shardings_params — normalization branches.""" - - def setUp(self): - """Build a minimal mesh and sharded param for testing.""" - self.mesh = Mesh(np.array(jax.devices()), ("data",)) - - def _make_simple_params(self): - """Return (params, param_sharding, logical) without a .params attribute.""" - params = {"w": jnp.ones((4,))} - param_sharding = {"w": NamedSharding(self.mesh, PartitionSpec(None))} - logical = {"w": PartitionSpec(None)} - return params, param_sharding, logical - - def test_runs_without_error_dict_inputs(self): - """print_shardings_params should not raise with plain dict inputs.""" - params, param_sharding, logical = self._make_simple_params() - # Should complete without raising - maxtext_utils.print_shardings_params(params, param_sharding, logical) - - def test_runs_without_logical_annotations(self): - """logical_annotations=None should be handled (no logical column).""" - params, param_sharding, _ = self._make_simple_params() - maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) - - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py index bed2e699fa..7f8c784176 100644 --- a/tests/unit/model_creation_utils_test.py +++ b/tests/unit/model_creation_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2023–2026 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -12,27 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for model_creation_utils.py.""" +"""Tests for model_creation_utils.""" import dataclasses -import sys import unittest -from unittest.mock import MagicMock, patch import jax import jax.numpy as jnp -import flax.linen as nn -from flax import nnx -from jax.sharding import Mesh + from orbax import checkpoint as ocp -from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL -from maxtext.models import models -from maxtext.utils import maxtext_utils -from maxtext.utils import model_creation_utils +# Import the private helpers under test. from maxtext.utils.model_creation_utils import _fix_restore_args_for_shape_mismatch -from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides # --------------------------------------------------------------------------- @@ -50,8 +41,10 @@ def _is_fake_meta(x): # Monkey-patch the module-level helper so our fake metadata is recognised. 
-_orig_is_orbax = model_creation_utils._is_orbax_array_metadata # pylint: disable=protected-access -model_creation_utils._is_orbax_array_metadata = _is_fake_meta # pylint: disable=protected-access +import maxtext.utils.model_creation_utils as _mcu + +_orig_is_orbax = _mcu._is_orbax_array_metadata # pylint: disable=protected-access +_mcu._is_orbax_array_metadata = _is_fake_meta # pylint: disable=protected-access def _make_restore_arg(global_shape): @@ -66,34 +59,6 @@ def _make_restore_arg(global_shape): ) -def _make_config(**kwargs): - """Returns a minimal pyconfig suitable for model-creation tests.""" - extra = get_decoupled_parallelism_overrides() - defaults = { - "per_device_batch_size": 1.0, - "run_name": "test", - "enable_checkpointing": False, - "base_num_decoder_layers": 2, - "attention": "dot_product", - "max_target_length": 16, - "base_emb_dim": 256, - "base_num_query_heads": 2, - "base_num_kv_heads": 2, - "max_prefill_predict_length": 4, - } - defaults.update(kwargs) - return pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **defaults, - **extra, - ) - - -def _make_mesh(config): - devices_array = maxtext_utils.create_device_mesh(config) - return Mesh(devices_array, config.mesh_axes) - - class FixRestoreArgsRankGuardTest(unittest.TestCase): """_fix_restore_args_for_shape_mismatch must not touch args when stored rank != model rank.""" @@ -141,275 +106,5 @@ def test_scanned_both_same_rank_shape_mismatch_is_modified(self): self.assertIsNone(arg.global_shape) -class TestGetTransformerModel(unittest.TestCase): - """Tests for get_transformer_model().""" - - def setUp(self): - self.config = _make_config() - self.mesh = _make_mesh(self.config) - - def test_returns_linen_module_when_rngs_is_none(self): - """Without rngs, should return a Linen nn.Module.""" - model = model_creation_utils.get_transformer_model(self.config, self.mesh, quant=None, rngs=None) - self.assertIsInstance(model, nn.Module) - - def 
test_returns_nnx_module_when_rngs_provided(self): - """With rngs, should return an NNX nnx.Module.""" - model = nnx.eval_shape( - lambda: model_creation_utils.get_transformer_model( - self.config, self.mesh, quant=None, rngs=nnx.Rngs(params=0, dropout=1, aqt=2) - ) - ) - self.assertIsInstance(model, nnx.Module) - - def test_respects_model_mode_prefill(self): - """Linen model created with MODEL_MODE_PREFILL should differ from train mode.""" - linen_train = model_creation_utils.get_transformer_model( - self.config, self.mesh, quant=None, model_mode=MODEL_MODE_TRAIN, rngs=None - ) - linen_prefill = model_creation_utils.get_transformer_model( - self.config, self.mesh, quant=None, model_mode=MODEL_MODE_PREFILL, rngs=None - ) - # Both are still nn.Module instances - self.assertIsInstance(linen_train, nn.Module) - self.assertIsInstance(linen_prefill, nn.Module) - - -class TestCreateModel(unittest.TestCase): - """Tests for create_model().""" - - def setUp(self): - self.config = _make_config() - self.mesh = _make_mesh(self.config) - - def test_returns_linen_model_without_rngs(self): - model = model_creation_utils.create_model(self.config, self.mesh) - self.assertIsInstance(model, nn.Module) - - def test_returns_nnx_model_with_rngs(self): - model = nnx.eval_shape( - lambda: model_creation_utils.create_model(self.config, self.mesh, rngs=nnx.Rngs(params=0, dropout=1, aqt=2)) - ) - self.assertIsInstance(model, nnx.Module) - - def test_model_mode_train_default(self): - """Default model_mode is MODEL_MODE_TRAIN.""" - model = model_creation_utils.create_model(self.config, self.mesh) - self.assertIsInstance(model, nn.Module) - - -class TestFromConfig(unittest.TestCase): - """Tests for from_config().""" - - def setUp(self): - self.config = _make_config() - self.mesh = _make_mesh(self.config) - - def test_linen_path_rngs_none(self): - """from_config with rngs=None should return a Linen nn.Module.""" - model = model_creation_utils.from_config(self.config, mesh=self.mesh, rngs=None) - 
self.assertIsInstance(model, nn.Module) - - def test_nnx_path_with_rngs(self): - """from_config with rngs provided should return an NNX nnx.Module.""" - model = nnx.eval_shape( - lambda: model_creation_utils.from_config(self.config, mesh=self.mesh, rngs=nnx.Rngs(params=0, dropout=1, aqt=2)) - ) - self.assertIsInstance(model, nnx.Module) - - def test_mesh_created_from_devices_when_none(self): - """from_config should work when mesh is None (creates mesh internally).""" - model = model_creation_utils.from_config(self.config, devices=None, mesh=None, rngs=None) - self.assertIsInstance(model, nn.Module) - - def test_model_mode_is_forwarded(self): - """from_config should accept and forward model_mode.""" - model = model_creation_utils.from_config(self.config, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL, rngs=None) - self.assertIsInstance(model, nn.Module) - - def test_explicit_shard_mode_creates_mesh_with_explicit_axis_types(self): - """from_config with shard_mode=explicit should create mesh using AxisType.Explicit.""" - cfg = _make_config(shard_mode="explicit") - # Should not raise; mesh is built with AxisType.Explicit for each axis - model = model_creation_utils.from_config(cfg, mesh=None, rngs=None) - self.assertIsInstance(model, nn.Module) - - -class TestCreateNNXAbstractModel(unittest.TestCase): - """Tests for create_nnx_abstract_model().""" - - def setUp(self): - self.config = _make_config() - self.mesh = _make_mesh(self.config) - - def test_returns_tuple_of_callable_and_module(self): - create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) - self.assertTrue(callable(create_fn)) - self.assertIsInstance(abstract_model, nnx.Module) - - def test_abstract_model_has_abstract_arrays(self): - """Abstract model leaves should be ShapeDtypeStruct, not concrete arrays.""" - _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) - _, state = nnx.split(abstract_model) - leaves = 
jax.tree.leaves(state) - self.assertGreater(len(leaves), 0) - for leaf in leaves: - # In abstract state, values are nnx.Variable wrapping abstract shapes/ShapeDtypeStruct - # Concrete jax.Array would have a .devices() method; abstract ones should not be Arrays - self.assertNotIsInstance(leaf, jax.Array) - - def test_create_fn_produces_concrete_model(self): - """The returned create_fn should produce a real (concrete) NNX Module.""" - create_fn, _ = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) - with self.mesh: - concrete = create_fn() - self.assertIsInstance(concrete, nnx.Module) - leaves = jax.tree.leaves(nnx.state(concrete)) - for leaf in leaves: - self.assertIsInstance(leaf, jax.Array) - - def test_works_without_explicit_mesh(self): - """create_nnx_abstract_model should work when mesh=None (from_config creates mesh).""" - create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=None) - self.assertTrue(callable(create_fn)) - self.assertIsInstance(abstract_model, nnx.Module) - - def test_explicit_rng_key_is_used(self): - """Passing a rng_key should not raise and returns valid abstract model.""" - rng_key = jax.random.PRNGKey(42) - create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model( - self.config, mesh=self.mesh, rng_key=rng_key - ) - self.assertTrue(callable(create_fn)) - self.assertIsInstance(abstract_model, nnx.Module) - - def test_prefill_model_mode(self): - """create_nnx_abstract_model should accept MODEL_MODE_PREFILL.""" - _, abstract_model = model_creation_utils.create_nnx_abstract_model( - self.config, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL - ) - self.assertIsInstance(abstract_model, nnx.Module) - - -class TestCreateNnxModel(unittest.TestCase): - """Tests for create_nnx_model().""" - - def setUp(self): - self.config = _make_config() - self.mesh = _make_mesh(self.config) - - def test_no_checkpoint_returns_model_and_mesh(self): - """Without load_parameters_path, 
should return (model, mesh) cleanly.""" - model, mesh = model_creation_utils.create_nnx_model(self.config, self.mesh) - self.assertIsInstance(model, models.Transformer) - self.assertIsInstance(mesh, Mesh) - - def test_mesh_none_uses_abstract_model_mesh(self): - """When mesh=None is passed, the function resolves it from the abstract model.""" - model, mesh = model_creation_utils.create_nnx_model(self.config, mesh=None) - self.assertIsInstance(model, models.Transformer) - self.assertIsInstance(mesh, Mesh) - - def test_explicit_rng_key(self): - """An explicit rng_key should be accepted without error.""" - rng_key = jax.random.PRNGKey(99) - model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, rng_key=rng_key) - self.assertIsInstance(model, models.Transformer) - - def test_inference_mode_disables_dropout_rng(self): - """MODEL_MODE_PREFILL should create rngs without a dropout key.""" - model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, model_mode=MODEL_MODE_PREFILL) - self.assertIsInstance(model, models.Transformer) - - def test_debug_sharding_flag(self): - """debug_sharding=True should execute the sharding-print path without error.""" - cfg = _make_config(debug_sharding=True) - model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) - self.assertIsInstance(model, models.Transformer) - - # ---- checkpoint loading: mocked paths ---- - - def _make_linen_metadata_mock(self): - """Mock ocp metadata that looks like a Linen checkpoint.""" - meta = MagicMock() - meta.item_metadata.tree.keys.return_value = ["params"] - meta.item_metadata.tree.get.return_value = {"params": {}} - return meta - - def _make_nnx_metadata_mock(self): - """Mock ocp metadata that looks like an NNX checkpoint.""" - meta = MagicMock() - meta.item_metadata.tree.keys.return_value = ["decoder"] - meta.item_metadata.tree.get.return_value = {} - return meta - - @patch("maxtext.utils.model_creation_utils.ocp") - def test_load_nnx_checkpoint(self, mock_ocp): - 
"""NNX-format checkpoint: restored values are wrapped under a 'value' key.""" - _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, self.mesh) - _, abstract_state = nnx.split(abstract_model) - - # Build a fake restored dict with 'value' keys (NNX checkpoint structure). - # Use concrete zero arrays (not ShapeDtypeStruct) so device_put in - # _expand_checkpoint_to_model_shapes receives a valid JAX array. - fake_restored = jax.tree.map( - lambda v: {"value": jnp.zeros(v.value.shape, v.value.dtype)}, - abstract_state, - is_leaf=lambda n: isinstance(n, nnx.Variable), - ) - - mock_ckptr = MagicMock() - mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock() - mock_ckptr.restore.return_value = fake_restored - mock_ocp.Checkpointer.return_value = mock_ckptr - mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() - mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} - mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs - - cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/nnx_ckpt") - model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) - self.assertIsInstance(model, models.Transformer) - - @patch("maxtext.utils.model_creation_utils.ocp") - def test_load_linen_checkpoint(self, mock_ocp): - """Linen-format checkpoint: restored values are nested under 'params'/'params'.""" - _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, self.mesh) - _, abstract_state = nnx.split(abstract_model) - - # Build fake plain-value dict (Linen structure). - # Use concrete zero arrays so device_put in _expand_checkpoint_to_model_shapes - # receives a valid JAX array (not a ShapeDtypeStruct). 
- fake_params = jax.tree.map( - lambda v: jnp.zeros(v.value.shape, v.value.dtype), - abstract_state, - is_leaf=lambda n: isinstance(n, nnx.Variable), - ) - fake_restored = {"params": {"params": fake_params}} - - mock_ckptr = MagicMock() - mock_ckptr.metadata.return_value = self._make_linen_metadata_mock() - mock_ckptr.restore.return_value = fake_restored - mock_ocp.Checkpointer.return_value = mock_ckptr - mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() - mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} - mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs - - cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/linen_ckpt") - model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) - self.assertIsInstance(model, models.Transformer) - - @patch("maxtext.utils.model_creation_utils.ocp") - def test_checkpoint_load_error_raises_value_error(self, mock_ocp): - """Any exception during checkpoint loading should be re-raised as ValueError.""" - mock_ckptr = MagicMock() - mock_ckptr.metadata.side_effect = RuntimeError("disk on fire") - mock_ocp.Checkpointer.return_value = mock_ckptr - mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() - - cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/bad_ckpt") - with self.assertRaises(ValueError): - model_creation_utils.create_nnx_model(cfg, self.mesh) - - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py index c9e4deb725..2cd696f241 100644 --- a/tests/unit/sharding_compare_test.py +++ b/tests/unit/sharding_compare_test.py @@ -14,7 +14,6 @@ """Compare expected sharding of models with actual sharding of models.""" -import functools import hashlib import json import os @@ -128,9 +127,6 @@ def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) f"model_name={model_name}", "log_config=false", "debug_sharding=true", # for input sharding dump - 
"pure_nnx=False", - "enable_nnx=False", - "pure_nnx_decoder=False", ] root_dir = "tests/utils/sharding_info" @@ -219,9 +215,6 @@ def abstract_state_and_shardings(request): f"compile_topology_num_slices={num_slice}", f"model_name={model_name}", "weight_dtype=float32", - "pure_nnx=False", - "enable_nnx=False", - "pure_nnx_decoder=False", ] config = pyconfig.initialize(params) validate_config(config) @@ -235,15 +228,13 @@ def abstract_state_and_shardings(request): tx = optimizers.get_optimizer(config, learning_rate_schedule) rng = jax.random.PRNGKey(0) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) - # Get abstract state and physical shardings from maxtext_utils abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - config, topology_mesh, init_state_fn, is_training=True + model, tx, config, rng, topology_mesh, is_training=True ) # Get logical shardings from maxtext_utils - logical_shardings = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) + logical_shardings = maxtext_utils.get_logical_annotations(model, tx, config, rng, topology_mesh, is_training=True) return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings diff --git a/tests/unit/state_dtypes_test.py b/tests/unit/state_dtypes_test.py index 10db1bf199..77e166193a 100644 --- a/tests/unit/state_dtypes_test.py +++ b/tests/unit/state_dtypes_test.py @@ -13,7 +13,6 @@ # limitations under the License. """ Test that all weights are expected dtype (default float32) """ -from functools import partial import unittest import jax @@ -48,12 +47,7 @@ def get_state(self, argv): tx = optimizers.get_optimizer(config, learning_rate_schedule) _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) - abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) + abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, example_rng, mesh) return abstract_state def get_weights(self, argv): diff --git a/tests/unit/train_utils_test.py b/tests/unit/train_utils_test.py deleted file mode 100644 index a8b9458794..0000000000 --- a/tests/unit/train_utils_test.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Unit tests for train_utils.py.""" - -import unittest -from dataclasses import dataclass -from unittest.mock import MagicMock - -from maxtext.utils.train_utils import validate_train_config, create_training_optimizer - - -@dataclass -class MockConfig: - """Minimal mock config for validate_train_config tests.""" - - run_name: str = "test_run" - dataset_path: str = "gs://test-bucket/data" - base_output_directory: str = "gs://test-bucket/output" - steps: int = 100 - quantization: str = "" - gradient_accumulation_steps: int = 1 - packing: bool = False - dataset_type: str = "tfds" - - # Fields needed for create_training_optimizer - opt_type: str = "adamw" - adam_b1: float = 0.9 - adam_b2: float = 0.95 - adam_eps: float = 1e-8 - adam_eps_root: float = 0.0 - adam_weight_decay: float = 0.1 - mu_dtype: str = "" - learning_rate: float = 1e-4 - learning_rate_schedule_steps: int = 1000 - warmup_steps_fraction: float = 0.1 - cosine_learning_rate_final_fraction: float = 0.0 - steps: int = 100 - lr_schedule_type: str = "cosine" - use_iota_embed: bool = False - - -class TestValidateTrainConfig(unittest.TestCase): - """Tests for validate_train_config.""" - - def test_valid_config_passes(self): - """Verifies no exception raised for a valid config.""" - config = MockConfig() - # Should not raise - validate_train_config(config) - - def test_missing_run_name_raises(self): - """Verifies AssertionError when run_name is empty.""" - config = MockConfig(run_name="") - with self.assertRaises(AssertionError): - validate_train_config(config) - - def test_zero_steps_raises(self): - """Verifies AssertionError when steps is 0.""" - config = MockConfig(steps=0) - with self.assertRaises(AssertionError): - validate_train_config(config) - - def test_negative_steps_raises(self): - """Verifies AssertionError when steps is negative.""" - config = MockConfig(steps=-5) - with self.assertRaises(AssertionError): - validate_train_config(config) - - def test_fp8_with_grad_accumulation_raises(self): - 
"""Verifies AssertionError for fp8 quantization + gradient_accumulation_steps > 1.""" - config = MockConfig(quantization="fp8", gradient_accumulation_steps=2) - with self.assertRaises(AssertionError): - validate_train_config(config) - - def test_nanoo_fp8_with_grad_accumulation_raises(self): - """Verifies AssertionError for nanoo_fp8 quantization + gradient_accumulation_steps > 1.""" - config = MockConfig(quantization="nanoo_fp8", gradient_accumulation_steps=4) - with self.assertRaises(AssertionError): - validate_train_config(config) - - def test_fp8_with_single_grad_accumulation_passes(self): - """Verifies no error for fp8 with gradient_accumulation_steps=1.""" - config = MockConfig(quantization="fp8", gradient_accumulation_steps=1) - validate_train_config(config) # Should not raise - - def test_packing_with_synthetic_data_logs_warning(self): - """Verifies no exception for packing + synthetic (just logs a warning).""" - config = MockConfig(packing=True, dataset_type="synthetic") - # Should not raise - just log a warning - validate_train_config(config) - - def test_local_dataset_path_logs_warning(self): - """Verifies no exception for local dataset_path (just logs a warning).""" - config = MockConfig(dataset_path="/local/path/to/data") - validate_train_config(config) # Should not raise - - def test_local_output_directory_logs_warning(self): - """Verifies no exception for local base_output_directory (just logs a warning).""" - config = MockConfig(base_output_directory="/local/output") - validate_train_config(config) # Should not raise - - -class TestCreateTrainingOptimizer(unittest.TestCase): - """Tests for create_training_optimizer.""" - - def _make_config(self, opt_type="adamw", **kwargs): - """Creates a mock config for optimizer tests.""" - cfg = MockConfig(opt_type=opt_type, **kwargs) - return cfg - - def _mock_lr_schedule(self): - """Returns a mock learning rate schedule that returns a fixed value.""" - return lambda step: 1e-4 - - def 
test_adamw_optimizer_returns_schedule_and_tx(self): - """Verifies create_training_optimizer returns a schedule and optax transform for adamw.""" - config = MagicMock() - config.opt_type = "adamw" - config.adam_b1 = 0.9 - config.adam_b2 = 0.999 - config.adam_eps = 1e-8 - config.adam_eps_root = 0.0 - config.adam_weight_decay = 0.01 - config.mu_dtype = None - config.learning_rate = 1e-4 - config.warmup_steps_fraction = 0.1 - config.cosine_learning_rate_final_fraction = 0.0 - config.steps = 100 - config.learning_rate_schedule_steps = 100 - config.lr_schedule_type = "cosine" - config.use_iota_embed = False - - schedule, tx = create_training_optimizer(config, model=None) - - self.assertIsNotNone(schedule) - self.assertIsNotNone(tx) - # Verify it's an optax GradientTransformation - self.assertTrue(hasattr(tx, "init")) - self.assertTrue(hasattr(tx, "update")) - - def test_adam_pax_optimizer_returns_tx(self): - """Verifies create_training_optimizer works for adam_pax optimizer.""" - config = MagicMock() - config.opt_type = "adam_pax" - config.adam_b1 = 0.9 - config.adam_b2 = 0.999 - config.adam_eps = 1e-8 - config.adam_eps_root = 0.0 - config.adam_weight_decay = 0.01 - config.mu_dtype = None - config.learning_rate = 1e-4 - config.warmup_steps_fraction = 0.1 - config.cosine_learning_rate_final_fraction = 0.0 - config.steps = 100 - config.learning_rate_schedule_steps = 100 - config.lr_schedule_type = "cosine" - config.use_iota_embed = False - - _, tx = create_training_optimizer(config, model=None) - - self.assertIsNotNone(tx) - self.assertTrue(hasattr(tx, "init")) - self.assertTrue(hasattr(tx, "update")) - - def test_sgd_optimizer_returns_tx(self): - """Verifies create_training_optimizer works for sgd optimizer.""" - config = MagicMock() - config.opt_type = "sgd" - config.learning_rate = 1e-4 - config.warmup_steps_fraction = 0.0 - config.cosine_learning_rate_final_fraction = 0.0 - config.steps = 100 - config.learning_rate_schedule_steps = 100 - config.lr_schedule_type = 
"cosine" - config.use_iota_embed = False - - _, tx = create_training_optimizer(config, model=None) - - self.assertIsNotNone(tx) - self.assertTrue(hasattr(tx, "init")) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index be9fb05049..4d6f9982b6 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -37,7 +37,6 @@ """Check if the logits generated by a model's src/maxtext/HF implementation matches golden logits for the same inputs""" import argparse -import functools import os from pathlib import Path import sys @@ -252,13 +251,8 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, config, False, rng1) - state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) + model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + state, _ = maxtext_utils.setup_decode_state(model, config, rng1, mesh, None) if test_args.golden_logits_path == "": input_golden_data_path = os.path.join( @@ -441,13 +435,8 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, rng1) - maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) + maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None) prompts = ["I love to", "Today is a", "What is the"] all_data_to_save = []