diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index bdec9c1f10..c4f477a971 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -18,6 +18,7 @@ import functools import pickle import os +from typing import Sequence from flax import linen as nn from flax.linen import partitioning as nn_partitioning @@ -27,6 +28,7 @@ from jax.experimental import mesh_utils from jax.experimental.serialize_executable import deserialize_and_load +from jax.sharding import AxisType, Mesh import jax import jax.numpy as jnp @@ -36,7 +38,8 @@ import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager -from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.configs import pyconfig +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode from maxtext.configs import types from maxtext.inference.page_manager import PageState from maxtext.common import checkpointing @@ -1531,3 +1534,27 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs): delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging all_host_upload=False, # Only upload from lead host (Host 0) ) + + +def get_mesh_from_config( + config: pyconfig.HyperParameters, + devices: Sequence[jax.Device] | None = None, +) -> Mesh: + """ + Get mesh from the configuration. 
+ + Args: + config: the configuration + devices: the devices + + Returns: + the device mesh + """ + devices_array = create_device_mesh(config, devices) + + if config.shard_mode == ShardMode.EXPLICIT: + axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) + else: + axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) + + return Mesh(devices_array, config.mesh_axes, axis_types=axis_types) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 49fb9d3490..0d547f6d71 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -19,20 +19,18 @@ from collections.abc import Sequence from functools import partial from typing import overload - from etils import epath from flax import nnx import flax.linen as nn import jax import jax.numpy as jnp -from jax.sharding import AxisType, Mesh +from jax.sharding import Mesh from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_logging -from maxtext.utils import max_utils -from maxtext.utils import maxtext_utils +from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx from orbax import checkpoint as ocp try: @@ -154,6 +152,7 @@ def from_config( mesh: Mesh | None = None, *, model_mode: str = MODEL_MODE_TRAIN, + rngs: None = None, ) -> nn.Module: ... 
@@ -194,15 +193,7 @@ def from_config( model = from_config(config) """ if mesh is None: - devices_array = maxtext_utils.create_device_mesh(config, devices) - - if config.shard_mode == ShardMode.EXPLICIT: - axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) - else: - axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) - - mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types) - + mesh = maxtext_utils.get_mesh_from_config(config, devices) model = create_model(config, mesh, model_mode=model_mode, rngs=rngs) # Return only the model @@ -260,16 +251,10 @@ def _create_model(rng_key=None): def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" + is_training = model_mode == MODEL_MODE_TRAIN def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - if rng_key is None: - rng_key = jax.random.PRNGKey(config.init_weights_seed) - - if model_mode == MODEL_MODE_TRAIN: - rngs = nnx.Rngs(params=rng_key, dropout=1) - else: - rngs = nnx.Rngs(params=rng_key) # disable dropout RNG for inference - + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key) return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) @@ -282,6 +267,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, if mesh is None: mesh = abstract_model.mesh + # Note for pure_nnx: + # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and + # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen + # LogicallyPartitioned structure. 
+ # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned + # structure in the abstract state and we can get the sharded state with the following code: + # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) + # abstract_model = nnx.merge(graphdef, state) + # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) + # sharded_state = nnx.state(model) + # JIT a function that creates the model state with proper sharding from the start. # By providing out_shardings, we instruct JAX to produce sharded output directly, # avoiding a large intermediate allocation on a single device. diff --git a/tests/unit/maxtext_utils_nnx_test.py b/tests/unit/maxtext_utils_nnx_test.py new file mode 100644 index 0000000000..0eb1f7ef77 --- /dev/null +++ b/tests/unit/maxtext_utils_nnx_test.py @@ -0,0 +1,182 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" Tests for the common MaxText NNX utilities """ +import unittest +from dataclasses import dataclass +from typing import Any +import jax +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +from jax.experimental import mesh_utils + +from maxtext.utils import maxtext_utils_nnx + + +class TestMaxTextUtilsNNX(unittest.TestCase): + """Test the functions for MaxText Utils.""" + + @dataclass + class MockConfig: + """Minimal mock for pyconfig.HyperParameters.""" + + init_weights_seed: int = 42 + + class TinyModel(nnx.Module): + """ + A tiny NNX model with logical annotations. + Annotations are required to test that sharding extraction logic works. + """ + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear( + jax.device_count(), + jax.device_count(), + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("data", None)), + # FIX: Removed () from zeros. zeros is the initializer function itself, + # not a factory like lecun_normal(). + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("data",)), + rngs=rngs, + ) + + def tiny_model_init_fn(self): + """Factory function for model initialization.""" + return self.TinyModel(rngs=nnx.Rngs(0)) + + def setUp(self): + # Create a mesh for sharding tests. + # NamedSharding requires an active Mesh to resolve logical names. + self.devices = mesh_utils.create_device_mesh((jax.device_count(),)) + self.mesh = Mesh(self.devices, axis_names=("data",)) + + def test_create_nnx_rngs_training(self): + # Using Any to satisfy static type checkers for the MockConfig + config: Any = self.MockConfig(init_weights_seed=123) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=True) + + self.assertIsInstance(rngs, nnx.Rngs) + # FIX: nnx.Rngs does not have a .streams attribute. + # Check for stream attributes directly on the object. 
+ self.assertTrue(hasattr(rngs, "params")) + self.assertTrue(hasattr(rngs, "dropout")) + self.assertTrue(hasattr(rngs, "aqt")) + + def test_create_nnx_rngs_inference(self): + config: Any = self.MockConfig(init_weights_seed=123) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=False) + + self.assertIsInstance(rngs, nnx.Rngs) + # Check that 'params' exists but 'dropout' and 'aqt' were excluded + self.assertTrue(hasattr(rngs, "params")) + self.assertFalse(hasattr(rngs, "dropout")) + self.assertFalse(hasattr(rngs, "aqt")) + + def test_move_memory(self): + sharding = NamedSharding(self.mesh, P("data")) + self.assertNotEqual(sharding.memory_kind, "pinned_host") + + path = ("layers", "linear", "kernel") + host_sharding = maxtext_utils_nnx.move_memory_to_host(path, sharding) + + self.assertEqual(host_sharding.memory_kind, "pinned_host") + self.assertEqual(host_sharding.spec, P("data")) + + device_sharding = maxtext_utils_nnx.move_memory_to_device(path, sharding) + + self.assertEqual(device_sharding.memory_kind, "device") + self.assertEqual(device_sharding.spec, P("data")) + + def test_get_set_named_sharding_nnx(self): + # 1. Create the abstract state using standard NNX functional API + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + + # 2. 
Test extraction + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + # Verify kernel and bias match the P("data") annotations from TinyModel + self.assertEqual(extracted_shardings.linear.kernel.get_value().spec, P("data", None)) + self.assertEqual(extracted_shardings.linear.bias.get_value().spec, P("data")) + + # Target kernel spec update + new_kernel_spec = P(None, "data") + + def update_spec_fn(path, leaf_sharding): + path_str = jax.tree_util.keystr(path) + if "linear" in path_str and "kernel" in path_str: + # Construct a new NamedSharding with the requested logical spec + return NamedSharding(leaf_sharding.mesh, new_kernel_spec) + return leaf_sharding + + # Apply the spec change to the extracted sharding tree + extracted_shardings = jax.tree.map_with_path(update_spec_fn, extracted_shardings) + + # 3. Test setting new shardings + # Transform the extracted shardings to host memory + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) + updated_abstract = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, new_shardings) + + # Verify the metadata inside the abstract state leaf has updated its sharding + self.assertEqual(updated_abstract.linear.kernel.sharding.memory_kind, "pinned_host") + # Also verify the spec was updated successfully + self.assertEqual(updated_abstract.linear.kernel.sharding.spec, new_kernel_spec) + + # 4. Verify named sharding is preserved after NNX merge (update) and split (state) + model = self.tiny_model_init_fn() + nnx.update(model, updated_abstract) + re_extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model)) + + # Verify kernel and bias have expected sharding + self.assertEqual(re_extracted_shardings.linear.kernel.get_value().spec, new_kernel_spec) + self.assertEqual(re_extracted_shardings.linear.bias.get_value().spec, P("data")) + + def test_create_nnx_sharded_model(self): + # 1. 
Create abstract model + graphdef, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + abstract_model = nnx.merge(graphdef, abstract_state) + + # 2. Modify shardings to trigger host offloading + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) + + # 3. Run the sharded creation + # We pass the abstract model and use the custom sharding for instantiation + sharded_model = maxtext_utils_nnx.create_nnx_sharded_model( + abstract_model, self.tiny_model_init_fn, mesh=self.mesh, named_sharding=new_shardings + ) + + # 4. Verify the model is concrete (contains Arrays) and sharded on host + self.assertIsInstance(sharded_model.linear.kernel[...], jax.Array) + self.assertEqual(sharded_model.linear.kernel[...].sharding.memory_kind, "pinned_host") + + def test_get_partition_spec_nnx(self): + """Verifies extraction of PartitionSpecs from NamedShardings.""" + # 1. Create abstract state and get sharding + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + # 2. Execute extraction + spec = maxtext_utils_nnx.get_partition_spec_nnx(extracted_shardings) + + # 3. 
Verify that the leaves are now raw PartitionSpecs + # Expected values derived from TinyModel definition + expected_spec_k = P("data", None) + expected_spec_b = P("data") + + self.assertEqual(spec["linear"]["kernel"], expected_spec_k) + self.assertEqual(spec["linear"]["bias"], expected_spec_b) + self.assertNotIsInstance(spec["linear"]["kernel"], NamedSharding) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 4850e972b3..7a09750a86 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -15,10 +15,11 @@ """Tests for the common MaxText utilities""" import functools -from typing import Any +from typing import Any, Sequence from collections.abc import Callable import unittest -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch +from dataclasses import dataclass, field from flax import linen as nn from flax import nnx @@ -27,9 +28,9 @@ import jax from jax import random, vmap import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec from maxtext.configs import pyconfig -from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode from maxtext.inference import inference_utils from maxtext.layers import quantizations from maxtext.models import models @@ -922,39 +923,65 @@ def test_wsd_schedule(self): self.assertIn("wsd_decay_steps_fraction", str(cm.exception)) -class TestGetAbstractState(unittest.TestCase): - """Test class for get_abstract_state.""" +class TestMeshUtils(unittest.TestCase): + """Test suite for the mesh creation utility function.""" - def setUp(self): - extra_args = get_decoupled_parallelism_overrides() - self.config = pyconfig.initialize( - [None, get_test_config_path()], - **extra_args, - 
enable_checkpointing=False, - model_name="llama3.1-8b", - per_device_batch_size=1, - max_target_length=16, - ) - devices_array = maxtext_utils.create_device_mesh(self.config) - self.mesh = Mesh(devices_array, self.config.mesh_axes) - quant = quantizations.configure_quantization(self.config) - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - self.rng = jax.random.PRNGKey(0) - self.tx = optax.adam(learning_rate=0.001) + @dataclass + class MockConfig: + """Minimal mock for pyconfig.HyperParameters.""" - def test_get_abstract_state(self): - """Tests that get_abstract_state returns abstract arrays.""" - # get_abstract_state returns a tuple, the first element is the abstract state. - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, self.tx, self.config, True, self.rng) - abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self.mesh, init_state_fn) - - # Check that params are abstract - param_leaves = jax.tree_util.tree_leaves(abstract_state.params) - self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in param_leaves)) + init_weights_seed: int = 42 + shard_mode: str = ShardMode.EXPLICIT + mesh_axes: Sequence[str] = field(default_factory=lambda: ["data", "model"]) - # Check that opt_state is abstract - opt_state_leaves = jax.tree_util.tree_leaves(abstract_state.opt_state) - self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in opt_state_leaves)) + def setUp(self): + # Setup a dummy device array for the mock to return + self.devices_array = np.array(jax.devices()) + + @patch("maxtext.utils.maxtext_utils.create_device_mesh") + def test_get_mesh_explicit_mode(self, mock_create_device_mesh): + """Tests that ShardMode.EXPLICIT sets axis_types to Explicit.""" + # 1. Setup Mock + mock_create_device_mesh.return_value = self.devices_array[:1].reshape((1,)) + config = self.MockConfig(shard_mode=ShardMode.EXPLICIT, mesh_axes=["data"]) + + # 2. 
Run function + mesh = maxtext_utils.get_mesh_from_config(config) + + # 3. Assertions + # Check that the internal utility was called correctly + mock_create_device_mesh.assert_called_once_with(config, None) + + # Verify Mesh properties + self.assertEqual(mesh.axis_names, ("data",)) + # ShardMode.EXPLICIT maps to jax.sharding.AxisType.Explicit + self.assertEqual(mesh.axis_types, (AxisType.Explicit,)) + + @patch("maxtext.utils.maxtext_utils.create_device_mesh") + def test_get_mesh_auto_mode(self, mock_create_device_mesh): + """Tests that ShardMode.AUTO sets axis_types to AUTO.""" + # 1. Setup Mock + mock_create_device_mesh.return_value = self.devices_array[:2].reshape((2, 1)) + config = self.MockConfig(shard_mode=ShardMode.AUTO, mesh_axes=["data", "model"]) + + # 2. Run function + mesh = maxtext_utils.get_mesh_from_config(config) + + # 3. Assertions + self.assertEqual(len(mesh.axis_types), 2) + self.assertTrue(all(t == AxisType.Auto for t in mesh.axis_types)) + + @patch("maxtext.utils.maxtext_utils.create_device_mesh") + def test_get_mesh_with_provided_devices(self, mock_create_device_mesh): + """Tests that provided devices are passed through to the mesh creator.""" + config = self.MockConfig() + specific_devices = self.devices_array[:2].reshape((1, 2)) + mock_create_device_mesh.return_value = specific_devices + + _ = maxtext_utils.get_mesh_from_config(config, devices=specific_devices) + + # Verify the second argument to create_device_mesh was our device list + mock_create_device_mesh.assert_called_once_with(config, specific_devices) class TestGetFunctionalTrainWithSignature(unittest.TestCase): diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py index bed2e699fa..ba4cb8817c 100644 --- a/tests/unit/model_creation_utils_test.py +++ b/tests/unit/model_creation_utils_test.py @@ -346,21 +346,13 @@ def _make_nnx_metadata_mock(self): @patch("maxtext.utils.model_creation_utils.ocp") def test_load_nnx_checkpoint(self, mock_ocp): 
"""NNX-format checkpoint: restored values are wrapped under a 'value' key.""" - _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, self.mesh) - _, abstract_state = nnx.split(abstract_model) - - # Build a fake restored dict with 'value' keys (NNX checkpoint structure). - # Use concrete zero arrays (not ShapeDtypeStruct) so device_put in - # _expand_checkpoint_to_model_shapes receives a valid JAX array. - fake_restored = jax.tree.map( - lambda v: {"value": jnp.zeros(v.value.shape, v.value.dtype)}, - abstract_state, - is_leaf=lambda n: isinstance(n, nnx.Variable), - ) - + # Echo back the `item` argument passed by create_nnx_model to ckptr.restore. + # For NNX checkpoints, item IS already {leaf: {"value": array}, ...}, so + # returning it directly gives a correctly-structured restored dict that + # matches the model's own state — regardless of the exact leaf count. mock_ckptr = MagicMock() mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock() - mock_ckptr.restore.return_value = fake_restored + mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item mock_ocp.Checkpointer.return_value = mock_ckptr mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} @@ -373,22 +365,13 @@ def test_load_nnx_checkpoint(self, mock_ocp): @patch("maxtext.utils.model_creation_utils.ocp") def test_load_linen_checkpoint(self, mock_ocp): """Linen-format checkpoint: restored values are nested under 'params'/'params'.""" - _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, self.mesh) - _, abstract_state = nnx.split(abstract_model) - - # Build fake plain-value dict (Linen structure). - # Use concrete zero arrays so device_put in _expand_checkpoint_to_model_shapes - # receives a valid JAX array (not a ShapeDtypeStruct). 
- fake_params = jax.tree.map( - lambda v: jnp.zeros(v.value.shape, v.value.dtype), - abstract_state, - is_leaf=lambda n: isinstance(n, nnx.Variable), - ) - fake_restored = {"params": {"params": fake_params}} - + # Echo back the `item` argument passed by create_nnx_model to ckptr.restore. + # For Linen checkpoints, item IS already {"params": {"params": arrays}}, so + # returning it directly gives a correctly-structured restored dict that + # matches the model's own state — regardless of the exact leaf count. mock_ckptr = MagicMock() mock_ckptr.metadata.return_value = self._make_linen_metadata_mock() - mock_ckptr.restore.return_value = fake_restored + mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item mock_ocp.Checkpointer.return_value = mock_ckptr mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() mock_ocp.checkpoint_utils.construct_restore_args.return_value = {}