diff --git a/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py new file mode 100644 index 0000000000..c103f234ee --- /dev/null +++ b/src/maxtext/checkpoint_conversion/compare_linen_nnx_checkpoint.py @@ -0,0 +1,609 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Compare checkpoint tree structures, shapes, and values. + +Supports comparing any combination of Linen and NNX checkpoints: +- Linen vs NNX (cross-format comparison) +- Linen vs Linen (same-format comparison) +- NNX vs NNX (same-format comparison) + +The script auto-detects the format of each checkpoint and applies the +appropriate normalization. Cross-format transformations (like layer axis +transposition) are only applied when comparing Linen vs NNX. + +Key differences between Linen and NNX checkpoints: +- Linen: params/params/decoder/layers/0/... (per-layer, double nested) +- NNX: model/decoder/layers/... 
(stacked layers, single nested, {value: array} wrappers) + +The script handles: +- Double 'params' nesting in Linen checkpoints +- 'model' key in NNX checkpoints (vs 'params' in Linen) +- {value: array} wrappers in NNX checkpoints +- Layer axis transposition (NNX stacks layers along axis 0, only for cross-format) +- RNG filtering (NNX has rngs, Linen doesn't) + +Usage: + # Compare Linen vs NNX (structure and shapes only) + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint/0/items" + + # Compare NNX vs NNX + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/nnx_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/nnx_checkpoint_b/0/items" + + # Compare Linen vs Linen + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/linen_checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/linen_checkpoint_b/0/items" + + # Compare with value checking + python compare_linen_nnx_checkpoint.py \ + --ckpt_path_1="gs://bucket/checkpoint_a/0/items" \ + --ckpt_path_2="gs://bucket/checkpoint_b/0/items" \ + --compare_values --atol=1e-5 --rtol=1e-5 +""" + +import os +from typing import Any, Dict, Sequence + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import jax.numpy as jnp +from jax.tree_util import tree_flatten_with_path, keystr, tree_structure, tree_map_with_path +import numpy as np +from etils import epath +import orbax.checkpoint as ocp +from absl import app +from absl import flags + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "ckpt_path_1", + None, + "Path to the first checkpoint items directory. Format is auto-detected.", + required=True, +) +flags.DEFINE_string( + "ckpt_path_2", + None, + "Path to the second checkpoint items directory. 
Format is auto-detected.", + required=True, +) +flags.DEFINE_boolean( + "verbose", + False, + "Print detailed per-parameter information.", +) +flags.DEFINE_boolean( + "transpose_nnx_layers", + False, + "Transpose NNX layer params from (layers, ...) to (...) for comparison. " + "NNX stacks layers along axis 0, while Linen stores per-layer params. " + "Only applied for cross-format (Linen vs NNX) comparisons.", +) +flags.DEFINE_string( + "compare_only", + "params", + "Which parts to compare: 'params' for params only, 'all' for full state.", +) +flags.DEFINE_boolean( + "ignore_rngs", + True, + "Ignore RNG-related paths in comparison (NNX has rngs, Linen doesn't).", +) +flags.DEFINE_boolean( + "compare_values", + False, + "Also compare parameter values (not just structure and shapes).", +) +flags.DEFINE_float( + "atol", + 1e-5, + "Absolute tolerance for value comparison.", +) +flags.DEFINE_float( + "rtol", + 1e-5, + "Relative tolerance for value comparison.", +) + + +def log(message: str) -> None: + """Log a message with prefix.""" + print(f"[compare_ckpt] {message}") + + +def is_rng_path(path: str) -> bool: + """Check if a path is RNG-related.""" + path_lower = path.lower() + return "rngs" in path_lower or "rng" in path_lower + + +def filter_rngs(tree: Dict[str, Any]) -> Dict[str, Any]: + """Filter out RNG-related keys from a tree.""" + if not isinstance(tree, dict): + return tree + + result = {} + for key, value in tree.items(): + # Skip RNG-related keys + if is_rng_path(key): + continue + # Recursively filter nested dicts + if isinstance(value, dict): + filtered = filter_rngs(value) + if filtered: # Only add if not empty after filtering + result[key] = filtered + else: + result[key] = value + return result + + +def detect_format(state: dict) -> str: + """Detects checkpoint format from state structure ('linen' or 'nnx'). + + Linen format: + - Top-level keys: ['params', 'opt_state', 'step'] + - params/params/decoder/... 
(double nested) + + NNX format: + - Top-level keys: ['model', 'optimizer'] (nnx.State style) + - model/decoder/... with {value: array} wrappers + """ + # Check for NNX nnx.State format (has 'model' key instead of 'params') + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Checkpoint does not contain 'params' or 'model' key. Found keys: {list(state.keys())}") + + params = state["params"] + + # Check for Linen's double 'params' nesting + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Check for NNX's flat structure (params/decoder/...) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + return "nnx" + + # Try to detect by looking for {value: array} wrappers (NNX style) + if _has_value_wrappers(params): + return "nnx" + + raise ValueError( + f"Could not detect checkpoint format. params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +def _has_value_wrappers(tree: Any) -> bool: + """Check if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {'value': array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, (np.ndarray, jnp.ndarray)): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _normalize_linen_params(params: dict) -> 
dict: + """Normalize Linen params by removing double 'params' nesting.""" + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return inner + return params + + +def _normalize_nnx_params(params: dict) -> dict: + """Normalize NNX params by stripping {value: array} wrappers.""" + return _strip_value_wrappers(params) + + +def load_checkpoint(checkpoint_path: str, metadata_only: bool = False) -> dict: + """Loads checkpoint from local or GCS path. + + If metadata_only=True, returns a pytree of ArrayMetadata (shape/dtype only) + without downloading any tensor data. This is fast and sufficient for + structure/shape comparison. + """ + log(f"Loading checkpoint from: {checkpoint_path}") + if metadata_only: + log(" Mode: metadata only (no tensor data downloaded)") + + checkpoint_dir = epath.Path(checkpoint_path) + + # Create checkpointer and get metadata + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + + try: + metadata = ckptr.metadata(checkpoint_dir) + + if metadata_only: + tree = metadata.item_metadata.tree + log(f" Loaded metadata keys: {list(tree.keys())}") + return tree + + # Create a mesh with all available devices for unsharded restoration + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + # Build restore args that restore arrays without original sharding + restore_args = jax.tree_util.tree_map( + lambda x: ocp.ArrayRestoreArgs(sharding=unsharded) if hasattr(x, "shape") else None, + metadata.item_metadata.tree, + is_leaf=lambda x: hasattr(x, "shape"), + ) + state = ckptr.restore(checkpoint_dir, restore_args=restore_args) + except Exception as e: # pylint: disable=broad-exception-caught + if metadata_only: + log(f" Metadata loading failed: {e}") + raise + # Fallback to simple restore without sharding 
args + log(f" Falling back to simple restore: {e}") + checkpointer = ocp.PyTreeCheckpointer() + state = checkpointer.restore(checkpoint_path) + + if state is None: + raise ValueError(f"Failed to restore checkpoint from {checkpoint_path}") + + log(f" Loaded keys: {list(state.keys())}") + return state + + +def transform_nnx_params_for_comparison(nnx_params: Dict[str, Any]) -> Dict[str, Any]: + """Transform NNX params to match Linen structure for comparison. + + NNX stacks layer parameters along axis 0 (shape: [num_layers, ...]), + while Linen stores per-layer parameters (shape: [...]). + + This function transposes layer params from (layers, d1, d2, ...) to (d1, layers, d2, ...) + to align with how Linen params would look if stacked. + """ + + def _transform(path, leaf: jax.Array) -> jax.Array: + key_str = keystr(path) + + # Only transform arrays in 'layers' with ndim >= 2 + if "layers" in key_str and hasattr(leaf, "ndim") and leaf.ndim >= 2: + # Transpose from (layers, d1, d2, ...) to (d1, layers, d2, ...) 
+ axes = (1, 0) + tuple(range(2, leaf.ndim)) + result = jnp.transpose(leaf, axes=axes) + if FLAGS.verbose: + log(f" TRANSPOSING: {key_str} shape {leaf.shape} -> {result.shape}") + return result + else: + return leaf + + log("Transforming NNX params (transposing layer dimensions)...") + return tree_map_with_path(_transform, nnx_params) + + +def get_tree_structure_info(tree: Dict[str, Any]) -> Dict[str, tuple]: + """Get structure info as dict of path -> (shape, dtype).""" + flat_with_path, _ = tree_flatten_with_path(tree) + return { + keystr(p): ( + getattr(leaf, "shape", "N/A"), + str(getattr(leaf, "dtype", type(leaf).__name__)), + ) + for p, leaf in flat_with_path + } + + +def print_structure_diff(params1: Dict, params2: Dict, name1: str = "Linen", name2: str = "NNX"): + """Print structural differences between two param trees.""" + info1 = get_tree_structure_info(params1) + info2 = get_tree_structure_info(params2) + keys1, keys2 = set(info1.keys()), set(info2.keys()) + + only_in_1 = sorted(keys1 - keys2) + only_in_2 = sorted(keys2 - keys1) + common = keys1 & keys2 + + if only_in_1: + print(f"\n--- Paths only in {name1} ({len(only_in_1)}) ---") + for k in only_in_1: + shape, dtype = info1[k] + print(f" - {k}: shape={shape}, dtype={dtype}") + + if only_in_2: + print(f"\n--- Paths only in {name2} ({len(only_in_2)}) ---") + for k in only_in_2: + shape, dtype = info2[k] + print(f" + {k}: shape={shape}, dtype={dtype}") + + # Check for shape/dtype mismatches in common paths + shape_mismatches = [] + dtype_mismatches = [] + for k in common: + shape1, dtype1 = info1[k] + shape2, dtype2 = info2[k] + if shape1 != shape2: + shape_mismatches.append((k, shape1, shape2)) + if dtype1 != dtype2: + dtype_mismatches.append((k, dtype1, dtype2)) + + if shape_mismatches: + print(f"\n--- Shape mismatches ({len(shape_mismatches)}) ---") + for k, s1, s2 in shape_mismatches: + print(f" {k}: {name1}={s1}, {name2}={s2}") + + if dtype_mismatches: + print(f"\n--- Dtype mismatches 
({len(dtype_mismatches)}) ---") + for k, d1, d2 in dtype_mismatches: + print(f" {k}: {name1}={d1}, {name2}={d2}") + + return only_in_1, only_in_2, shape_mismatches, dtype_mismatches + + +def compare_params( + params1: Dict[str, Any], + params2: Dict[str, Any], + verbose: bool = False, + compare_values: bool = False, + atol: float = 1e-5, + rtol: float = 1e-5, + name1: str = "Ckpt1", + name2: str = "Ckpt2", +) -> bool: + """Compare two parameter trees for structure, shape, and optionally values. + + Returns True if tree structures, shapes, and (optionally) values match. + """ + # First check tree structure + if tree_structure(params1) != tree_structure(params2): + print("\n[✗] Tree structures differ.") + print_structure_diff(params1, params2, name1=name1, name2=name2) + return False + + print("\n[✓] Tree structures are the same.") + + all_match = True + num_params = 0 + shape_mismatches = [] + dtype_mismatches = [] + value_mismatches = [] + value_matches = 0 + + def _compare_leaf(path, x, y): + nonlocal all_match, num_params, shape_mismatches, dtype_mismatches, value_mismatches, value_matches + key_str = keystr(path) + num_params += 1 + + shape1 = getattr(x, "shape", "N/A") + shape2 = getattr(y, "shape", "N/A") + dtype1 = getattr(x, "dtype", type(x).__name__) + dtype2 = getattr(y, "dtype", type(y).__name__) + + # Check shape + shape_match = shape1 == shape2 + if not shape_match: + shape_mismatches.append((key_str, shape1, shape2)) + all_match = False + + # Check dtype + dtype_match = str(dtype1) == str(dtype2) + if not dtype_match: + dtype_mismatches.append((key_str, dtype1, dtype2)) + all_match = False + + # Check values if requested and shapes match + if compare_values and shape_match and hasattr(x, "shape") and hasattr(y, "shape"): + try: + x_arr = np.asarray(x) + y_arr = np.asarray(y) + is_close = bool(np.allclose(x_arr, y_arr, atol=atol, rtol=rtol)) + + if is_close: + value_matches += 1 + if verbose: + print(f" [✓] {key_str} | Shape: {shape1} | Values match") + 
else: + diff = np.abs(x_arr - y_arr) + mean_diff = float(np.mean(diff)) + max_diff = float(np.max(diff)) + value_mismatches.append((key_str, mean_diff, max_diff)) + all_match = False + if verbose: + print(f" [✗] {key_str} | Shape: {shape1} | Mean diff: {mean_diff:.2e}, Max diff: {max_diff:.2e}") + except Exception as e: # pylint: disable=broad-exception-caught + value_mismatches.append((key_str, f"Error: {e}", "")) + all_match = False + elif verbose and not compare_values: + print(f" {key_str} | Shape: {shape1} | Dtype: {dtype1}") + + tree_map_with_path(_compare_leaf, params1, params2) + + # Print summary + print("\n--- Summary ---") + print(f"Total parameters: {num_params}") + + if shape_mismatches: + print(f"\n[✗] Shape mismatches ({len(shape_mismatches)}):") + for key_str, s1, s2 in shape_mismatches: + print(f" {key_str}: {name1}={s1}, {name2}={s2}") + else: + print("[✓] All shapes match.") + + if dtype_mismatches: + print(f"\n[✗] Dtype mismatches ({len(dtype_mismatches)}):") + for key_str, d1, d2 in dtype_mismatches: + print(f" {key_str}: {name1}={d1}, {name2}={d2}") + else: + print("[✓] All dtypes match.") + + if compare_values: + if value_mismatches: + print(f"\n[✗] Value mismatches ({len(value_mismatches)}):") + for item in value_mismatches[:20]: # Show first 20 + if len(item) == 3: + key_str, mean_diff, max_diff = item + if isinstance(mean_diff, float): + print(f" {key_str}: mean_diff={mean_diff:.2e}, max_diff={max_diff:.2e}") + else: + print(f" {key_str}: {mean_diff}") + if len(value_mismatches) > 20: + print(f" ... 
and {len(value_mismatches) - 20} more (use --verbose to see all)") + else: + print(f"[✓] All values match (atol={atol}, rtol={rtol}).") + print(f" Values matching: {value_matches}/{num_params}") + + return all_match + + +def _extract_params(state: dict, fmt: str) -> dict: + """Extract params from a checkpoint state based on its detected format.""" + if fmt == "linen": + return state.get("params", {}) + else: + # NNX format: params are in 'model' key + return state.get("model", state.get("params", {})) + + +def _normalize_params(params: dict, fmt: str) -> dict: + """Normalize params based on detected format.""" + if fmt == "linen": + return _normalize_linen_params(params) + else: + return _normalize_nnx_params(params) + + +def main(argv: Sequence[str]): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + ckpt_path_1 = FLAGS.ckpt_path_1 + ckpt_path_2 = FLAGS.ckpt_path_2 + + print("=" * 80) + print("Checkpoint Comparator") + print("=" * 80) + + print(f"\nCheckpoint 1: {ckpt_path_1}") + print(f"Checkpoint 2: {ckpt_path_2}") + print(f"Transpose NNX layers: {FLAGS.transpose_nnx_layers}") + print(f"Ignore RNGs: {FLAGS.ignore_rngs}") + print(f"Compare values: {FLAGS.compare_values}") + if FLAGS.compare_values: + print(f" Tolerance: atol={FLAGS.atol}, rtol={FLAGS.rtol}") + + # Load checkpoints — use metadata-only when not comparing values to avoid + # downloading tensor data (which can be 100+ GiB and cause XPK timeouts). 
+ metadata_only = not FLAGS.compare_values + print("\n" + "-" * 40) + state_1 = load_checkpoint(ckpt_path_1, metadata_only=metadata_only) + state_2 = load_checkpoint(ckpt_path_2, metadata_only=metadata_only) + + # Detect formats + format_1 = detect_format(state_1) + format_2 = detect_format(state_2) + log(f"Detected checkpoint 1 format: {format_1}") + log(f"Detected checkpoint 2 format: {format_2}") + + is_cross_format = format_1 != format_2 + name_1 = f"Ckpt1({format_1})" + name_2 = f"Ckpt2({format_2})" + + # Extract and normalize params + print("\n" + "-" * 40) + log("Normalizing parameters...") + + if FLAGS.compare_only == "params": + params_1 = _extract_params(state_1, format_1) + params_2 = _extract_params(state_2, format_2) + else: + params_1 = state_1 + params_2 = state_2 + + params_1 = _normalize_params(params_1, format_1) + log(f" Checkpoint 1 ({format_1}): normalized") + params_2 = _normalize_params(params_2, format_2) + log(f" Checkpoint 2 ({format_2}): normalized") + + # Filter out RNG paths if requested + if FLAGS.ignore_rngs: + print("\n" + "-" * 40) + log("Filtering out RNG-related paths...") + params_1 = filter_rngs(params_1) + params_2 = filter_rngs(params_2) + + # Transform NNX params for cross-format comparison (transpose layer dimensions) + # Only apply when comparing Linen vs NNX, not for same-format comparisons + if FLAGS.transpose_nnx_layers and is_cross_format: + print("\n" + "-" * 40) + if format_1 == "nnx": + params_1 = transform_nnx_params_for_comparison(params_1) + if format_2 == "nnx": + params_2 = transform_nnx_params_for_comparison(params_2) + + # Compare + print("\n" + "-" * 40) + log("Comparing parameters...") + + success = compare_params( + params_1, + params_2, + verbose=FLAGS.verbose, + compare_values=FLAGS.compare_values, + atol=FLAGS.atol, + rtol=FLAGS.rtol, + name1=name_1, + name2=name_2, + ) + + # Final verdict + print("\n" + "=" * 80) + if success: + print("CHECKPOINTS MATCH") + if FLAGS.compare_values: + print(" Tree 
structure, shapes, and values are identical!") + else: + print(" Tree structure and all shapes are identical!") + else: + print("CHECKPOINTS DIFFER") + print(" See details above for mismatches.") + print("=" * 80) + + return 0 if success else 1 + + +if __name__ == "__main__": + app.run(main) diff --git a/src/maxtext/checkpoint_conversion/linen_nnx_converter.py b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py new file mode 100644 index 0000000000..015d3b5a56 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/linen_nnx_converter.py @@ -0,0 +1,581 @@ +# Copyright 2023-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bidirectional conversion between Linen and NNX checkpoint formats. 
+ +Top-level key mapping: + Linen → NNX: + params/params/ → model/ (remove double-nesting, rename, add {value:} wrappers) + opt_state → optimizer/opt_state (remove 'params' level from mu/nu) + step → optimizer/step (move inside optimizer) + + NNX → Linen: + model/ → params/params/ (strip {value:} wrappers, add double-nesting) + optimizer/opt_state → opt_state (add 'params' level to mu/nu) + optimizer/step → step (move to top level) + +Layer structure (--scan_layers): + linen_to_nnx: + scan_layers=True (default): stack layers_N arrays → 'layers' tensor with layer dim at axis 1 + scan_layers=False: rename layers_N → integer-keyed 'layers/{N}' + + nnx_to_linen (auto-detected): + Stacked 'layers' tensor → unstack along axis 1 → layers_N per-layer arrays + Integer-keyed layers/{N} → rename to layers_N + +Usage: + python linen_nnx_converter.py \\ + --source_path="gs://bucket/checkpoint/0/items" \\ + --target_path="gs://bucket/converted/" \\ + --direction=auto +""" + +import argparse +import os +import re +import time +from typing import Any + +# MUST set before importing JAX to force CPU-only mode +os.environ["JAX_PLATFORMS"] = "cpu" + +import jax +import numpy as np +from etils import epath +import orbax.checkpoint as ocp + + +def log(message: str) -> None: + print(f"[linen_nnx_converter] {message}") + + +# ── Format detection ─────────────────────────────────────────────────────────── + + +def detect_format(state: dict) -> str: + """Detects checkpoint format ('linen' or 'nnx') from top-level keys.""" + # NNX: uses 'model' as the top-level params key + if "model" in state: + return "nnx" + + if "params" not in state: + raise ValueError(f"Cannot detect checkpoint format: no 'model' or 'params' key. 
" f"Found: {list(state.keys())}") + + params = state["params"] + + # Linen: double-nested params/params/decoder + if isinstance(params, dict) and "params" in params: + inner = params["params"] + if isinstance(inner, dict) and ("decoder" in inner or "encoder" in inner): + return "linen" + + # Old NNX format: params/decoder (single-nested with value wrappers) + if isinstance(params, dict) and ("decoder" in params or "encoder" in params): + if _has_value_wrappers(params): + return "nnx" + + if "optimizer" in state: + return "nnx" + if "opt_state" in state: + return "linen" + + raise ValueError( + f"Could not detect checkpoint format. Keys: {list(state.keys())}, " + f"params keys: {list(params.keys()) if isinstance(params, dict) else type(params)}" + ) + + +# ── Value wrapper helpers ────────────────────────────────────────────────────── + + +def _has_value_wrappers(tree: Any) -> bool: + """Returns True if tree contains {value: array} wrappers (NNX style).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return True + for v in tree.values(): + if _has_value_wrappers(v): + return True + return False + + +def _strip_value_wrappers(tree: Any) -> Any: + """Recursively strips {value: array} wrappers from a tree.""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + return {k: _strip_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_strip_value_wrappers(item) for item in tree) + else: + return tree + + +def _add_value_wrappers(tree: Any) -> Any: + """Recursively wraps leaf arrays in {value: array} (NNX nnx.Param format).""" + if isinstance(tree, dict): + if set(tree.keys()) == {"value"}: + inner = tree["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return tree # Already 
wrapped + return {k: _add_value_wrappers(v) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_add_value_wrappers(item) for item in tree) + elif hasattr(tree, "shape") or isinstance(tree, np.ndarray): + return {"value": tree} + else: + return tree + + +# ── Layer structure helpers ──────────────────────────────────────────────────── + + +def _stack_layers(decoder: dict) -> tuple[dict, bool]: + """Stacks per-layer parameters (layers_N) into a single 'layers' dict at axis 0. + + Returns (result_dict, was_stacked). + """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder, False + + sorted_indices = sorted(layer_indices.keys()) + num_layers = len(sorted_indices) + log(f" Found {num_layers} individual layers, stacking into 'layers'") + + def stack_arrays(layers_data: list) -> Any: + first = layers_data[0] + if hasattr(first, "shape") or isinstance(first, np.ndarray): + return np.stack([np.asarray(layers_data[i]) for i in range(len(layers_data))], axis=0) + elif isinstance(first, dict): + result = {} + for key in first.keys(): + child_data = [layers_data[i].get(key) for i in range(len(layers_data))] + if all(c is not None for c in child_data): + result[key] = stack_arrays(child_data) + return result + else: + return first + + layers_data = [layer_indices[i] for i in sorted_indices] + stacked = stack_arrays(layers_data) + + result = dict(other_keys) + result["layers"] = stacked + return result, True + + +def _rename_layers_to_integer_keys(decoder: dict) -> dict: + """Converts layers_N keys to integer-keyed dict under 'layers' (no stacking). + + Converts {layers_0: {...}, layers_1: {...}} → {layers: {'0': {...}, '1': {...}}}. + Used for scan_layers=False linen→nnx conversion (Pattern C). 
+ """ + layer_pattern = re.compile(r"^layers_(\d+)$") + layer_indices = {} + other_keys = {} + + for key, value in decoder.items(): + match = layer_pattern.match(key) + if match: + layer_indices[int(match.group(1))] = value + else: + other_keys[key] = value + + if not layer_indices: + return decoder + + sorted_indices = sorted(layer_indices.keys()) + log(f" Found {len(sorted_indices)} individual layers, renaming to integer-keyed 'layers/N'") + result = dict(other_keys) + result["layers"] = {str(i): layer_indices[i] for i in sorted_indices} + return result + + +def _transpose_layers_axes(tree: Any, src_axis: int, dst_axis: int) -> Any: + """Transposes the layers dimension in arrays within a tree (src_axis ↔ dst_axis).""" + if src_axis == dst_axis: + return tree + if isinstance(tree, dict): + return {k: _transpose_layers_axes(v, src_axis, dst_axis) for k, v in tree.items()} + elif isinstance(tree, (list, tuple)): + return type(tree)(_transpose_layers_axes(item, src_axis, dst_axis) for item in tree) + elif hasattr(tree, "shape") and len(tree.shape) >= 2: + axes = list(range(len(tree.shape))) + axes[src_axis], axes[dst_axis] = axes[dst_axis], axes[src_axis] + result = np.transpose(np.asarray(tree), axes=axes) + log(f" Transposed: {tree.shape} → {result.shape}") + return result + else: + return tree + + +def _detect_num_layers(tree: Any, scan_axis: int) -> int | None: + """Detects num_layers from the first array with ndim > scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, np.ndarray): + shape = getattr(tree, "shape", None) or np.asarray(tree).shape + if len(shape) > scan_axis: + return shape[scan_axis] + return None + if isinstance(tree, dict): + for v in tree.values(): + result = _detect_num_layers(v, scan_axis) + if result is not None: + return result + return None + + +def _unstack_single_layer(tree: Any, idx: int, scan_axis: int) -> Any: + """Extracts a single layer by indexing at scan_axis.""" + if hasattr(tree, "shape") or isinstance(tree, 
np.ndarray): + arr = np.asarray(tree) + if arr.ndim > scan_axis: + return np.take(arr, idx, axis=scan_axis) + return arr + if isinstance(tree, dict): + return {k: _unstack_single_layer(v, idx, scan_axis) for k, v in tree.items()} + if isinstance(tree, (list, tuple)): + return type(tree)(_unstack_single_layer(v, idx, scan_axis) for v in tree) + return tree + + +def _convert_layers_to_linen_format(decoder: dict) -> dict: + """Converts NNX 'layers' back to Linen's layers_N format (auto-detects NNX style). + + Handles: + - Stacked tensor (Pattern B): layers/ + → layers_0, layers_1, ... (unstack along axis 1) + - Integer-keyed (Pattern C): layers/0, layers/1, ... + → layers_0, layers_1, ... (rename) + """ + if "layers" not in decoder: + return decoder + + layers_val = decoder["layers"] + other_keys = {k: v for k, v in decoder.items() if k != "layers"} + + if not isinstance(layers_val, dict): + # Already a non-dict (shouldn't happen normally), keep as-is + return decoder + + # Pattern C: integer-keyed per-layer dict → rename + if all(k.isdigit() for k in layers_val.keys()): + result = dict(other_keys) + for idx_str, layer_data in sorted(layers_val.items(), key=lambda x: int(x[0])): + result[f"layers_{idx_str}"] = layer_data + log(f" Renamed integer-keyed layers/N → layers_N ({len(layers_val)} layers)") + return result + + # Pattern B: stacked tensor (layer dim at axis 1) → unstack + num_layers = _detect_num_layers(layers_val, scan_axis=1) + if num_layers is None: + log(" WARNING: Could not detect num_layers for unstacking, keeping 'layers' as-is") + result = dict(other_keys) + result["layers"] = layers_val + return result + + result = dict(other_keys) + for i in range(num_layers): + result[f"layers_{i}"] = _unstack_single_layer(layers_val, idx=i, scan_axis=1) + log(f" Unstacked scanned 'layers' → layers_N ({num_layers} layers at axis 1)") + return result + + +# ── Optimizer state helpers ──────────────────────────────────────────────────── + + +def 
_convert_opt_state_linen_to_nnx(opt_state: Any) -> Any: + """Removes 'params' nesting from mu/nu in linen opt_state. + + NNX optimizer state has plain arrays (no {value:} wrappers). + Linen opt_state mirrors the params structure (params/decoder/...), + so we remove the 'params' level to get decoder/... directly. + """ + if isinstance(opt_state, dict): + result = {} + for k, v in opt_state.items(): + if k == "params": + # Remove this level by merging its contents up + converted = _convert_opt_state_linen_to_nnx(v) + if isinstance(converted, dict): + result.update(converted) + else: + result[k] = converted + else: + result[k] = _convert_opt_state_linen_to_nnx(v) + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_linen_to_nnx(item) for item in opt_state) + else: + return opt_state # Plain array or scalar — no value wrapper for opt_state + + +def _convert_opt_state_nnx_to_linen(opt_state: Any, depth: int = 0) -> Any: + """Adds 'params' nesting to mu/nu, removes any stray {value:} wrappers. + + NNX optimizer mu/nu contains decoder/... directly. + Linen expects mu/params/decoder/... (one 'params' level mirroring the params structure). 
+ """ + if isinstance(opt_state, dict): + # Strip any {value:} wrappers in opt_state (shouldn't be there but handle gracefully) + if set(opt_state.keys()) == {"value"}: + inner = opt_state["value"] + if hasattr(inner, "shape") or isinstance(inner, np.ndarray): + return inner + + result = {} + for k, v in opt_state.items(): + converted = _convert_opt_state_nnx_to_linen(v, depth + 1) + # Add one 'params' level after mu/nu (mirrors linen's params structure) + if k in ("mu", "nu") and isinstance(converted, dict): + result[k] = {"params": converted} + else: + result[k] = converted + return result + elif isinstance(opt_state, (list, tuple)): + return type(opt_state)(_convert_opt_state_nnx_to_linen(item, depth + 1) for item in opt_state) + else: + return opt_state + + +# ── Main conversion functions ────────────────────────────────────────────────── + + +def convert_linen_to_nnx(state: dict, scan_layers: bool = True) -> dict: + """Converts Linen checkpoint to NNX format. + + Args: + state: Linen checkpoint dict with keys ['params', 'opt_state', 'step']. + scan_layers: If True (default), stack per-layer arrays and insert layer + dim at axis 1 (for NNX with scan_layers=True). + If False, rename layers_N → integer-keyed layers/N + (for NNX with scan_layers=False). 
+ """ + result = {} + + if "params" in state: + linen_params = state["params"] + # Remove double 'params' nesting: params/params/decoder → decoder + if isinstance(linen_params, dict) and "params" in linen_params: + nnx_params = linen_params["params"] + log(" params: Removed double 'params' nesting (params/params → model)") + else: + nnx_params = linen_params + log(" params: No double nesting found") + + stripped = _strip_value_wrappers(nnx_params) + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + if scan_layers: + stripped[component], was_stacked = _stack_layers(stripped[component]) + if was_stacked and "layers" in stripped[component]: + log(f" {component}/layers: Transposing stacked (layers, ...) → (..., layers, ...) at axis 1") + stripped[component]["layers"] = _transpose_layers_axes(stripped[component]["layers"], src_axis=0, dst_axis=1) + else: + stripped[component] = _rename_layers_to_integer_keys(stripped[component]) + + result["model"] = _add_value_wrappers(stripped) + log(" model: Saved with {value:} wrappers under 'model' key") + + # optimizer: move step inside, keep opt_state + optimizer_dict = {} + if "step" in state: + optimizer_dict["step"] = state["step"] + log(f" optimizer/step: Moved from top-level (step={state['step']})") + if "opt_state" in state: + optimizer_dict["opt_state"] = _convert_opt_state_linen_to_nnx(state["opt_state"]) + log(" optimizer/opt_state: Removed 'params' nesting from mu/nu") + if optimizer_dict: + result["optimizer"] = optimizer_dict + + return result + + +def convert_nnx_to_linen(state: dict) -> dict: + """Converts NNX checkpoint to Linen format. + + Reads from 'model'/'optimizer' keys (or falls back to old 'params'/'opt_state' format). + Layer structure is auto-detected (stacked vs integer-keyed). 
+ """ + result = {} + + model_key = "model" if "model" in state else "params" + if model_key in state: + nnx_params = state[model_key] + stripped = _strip_value_wrappers(nnx_params) + log(f" {model_key}: Removed {{value:}} wrappers") + + for component in ("decoder", "encoder"): + if component in stripped and isinstance(stripped[component], dict): + stripped[component] = _convert_layers_to_linen_format(stripped[component]) + + # Add double 'params' nesting: decoder → params/params/decoder + result["params"] = {"params": stripped} + log(" params: Added double 'params' nesting (model → params/params)") + + # optimizer: extract step and opt_state back to top level + if "optimizer" in state: + optimizer = state["optimizer"] + if "step" in optimizer: + result["step"] = optimizer["step"] + log(" step: Extracted from optimizer/step to top level") + if "opt_state" in optimizer: + result["opt_state"] = _convert_opt_state_nnx_to_linen(optimizer["opt_state"]) + log(" opt_state: Added 'params' nesting to mu/nu") + elif "opt_state" in state: + # Backward compat: old format with opt_state at top level + result["opt_state"] = _convert_opt_state_nnx_to_linen(state["opt_state"]) + log(" opt_state: Converted from top-level opt_state (old format)") + + if "step" in state and "step" not in result: + result["step"] = state["step"] + + return result + + +# ── Checkpoint I/O ───────────────────────────────────────────────────────────── + + +def load_checkpoint(checkpoint_path: str) -> dict: + """Loads checkpoint from local or GCS path.""" + log(f"Loading checkpoint from: {checkpoint_path}") + + checkpoint_dir = epath.Path(checkpoint_path) + ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler()) + metadata = ckptr.metadata(checkpoint_dir) + + devices = np.array(jax.devices()).reshape((-1,)) + single_device_mesh = jax.sharding.Mesh(devices, ("x",)) + unsharded = jax.sharding.NamedSharding(single_device_mesh, jax.sharding.PartitionSpec()) + + restore_args = jax.tree_util.tree_map( + lambda 
def save_checkpoint(state: dict, output_path: str) -> None:
  """Saves checkpoint tree to local or GCS path."""
  log(f"Saving checkpoint to: {output_path}")

  destination = epath.Path(output_path)
  destination.mkdir(exist_ok=True, parents=True)

  ocp.PyTreeCheckpointer().save(destination, state, force=True)
  log(" Checkpoint saved successfully")


# ── CLI ────────────────────────────────────────────────────────────────────────


def main():
  """CLI entry point: load a checkpoint, convert its format, save the result."""
  parser = argparse.ArgumentParser(
      description="Convert between Linen and NNX checkpoint formats.",
      formatter_class=argparse.RawDescriptionHelpFormatter,
  )
  parser.add_argument(
      "--source_path",
      type=str,
      required=True,
      help="Path to source checkpoint items directory (e.g. gs://bucket/ckpt/0/items).",
  )
  parser.add_argument(
      "--target_path",
      type=str,
      required=True,
      help="Path to save converted checkpoint.",
  )
  parser.add_argument(
      "--direction",
      type=str,
      choices=["auto", "linen_to_nnx", "nnx_to_linen"],
      default="auto",
      help="Conversion direction. 'auto' detects from source format.",
  )
  parser.add_argument(
      "--scan_layers",
      action=argparse.BooleanOptionalAction,
      default=True,
      help=(
          "For linen_to_nnx only: if True (default), stack per-layer arrays into a "
          "scanned 'layers' tensor with layer dim at axis 1 (for NNX with scan_layers=True). "
          "If False, rename layers_N to integer-keyed layers/N without stacking "
          "(for NNX with scan_layers=False)."
      ),
  )
  args = parser.parse_args()

  banner = "=" * 80
  print(banner)
  print("Linen <-> NNX Checkpoint Converter")
  print(banner)

  start_time = time.time()
  state = load_checkpoint(args.source_path)

  if args.direction == "auto":
    source_format = detect_format(state)
    target_format = "nnx" if source_format == "linen" else "linen"
    log(f"Auto-detected: {source_format} → {target_format}")
  else:
    # Direction strings are of the form '<src>_to_<dst>'.
    source_format, target_format = args.direction.split("_to_")
    log(f"Using specified direction: {source_format} → {target_format}")

  log(f"Converting: {source_format} → {target_format}")
  if source_format == "linen":
    log(f"scan_layers={args.scan_layers}")

  if (source_format, target_format) == ("linen", "nnx"):
    converted_state = convert_linen_to_nnx(state, scan_layers=args.scan_layers)
  elif (source_format, target_format) == ("nnx", "linen"):
    converted_state = convert_nnx_to_linen(state)
  else:
    raise ValueError(f"Invalid conversion: {source_format} → {target_format}")

  save_checkpoint(converted_state, args.target_path)

  elapsed = time.time() - start_time
  print("\n" + banner)
  print(f"Conversion complete in {elapsed:.2f} seconds")
  print(f" Source: {args.source_path}")
  print(f" Target: {args.target_path}")
  print(f" Direction: {source_format} → {target_format}")
  print(banner)


if __name__ == "__main__":
  main()
mesh = Mesh(devices_array, cfg.mesh_axes) quant = quantizations.configure_quantization(cfg) - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if cfg.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -98,7 +102,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, cfg, init_rng, mesh, checkpoint_manager) + if cfg.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index cdfde92d50..f9b5af575c 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -20,6 +20,7 @@ from absl import flags import datetime from etils import epath +from flax import nnx from flax.training import train_state import jax from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE @@ -532,7 +533,7 @@ def load_state_if_possible( load_parameters_from_path: str, load_full_state_from_path: str, checkpoint_storage_concurrent_gb: int, - abstract_unboxed_pre_state: train_state.TrainState, + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, enable_single_replica_ckpt_restoring: bool | None = False, dataset_type: str | None = "tfds", step: int = -1, 
# -1 means latest @@ -600,8 +601,13 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) - checkpoint_args = ocp.args.PyTreeRestore(item=abstract_unboxed_pre_state, restore_args=restore_args) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX + restore_target = abstract_unboxed_pre_state + if isinstance(abstract_unboxed_pre_state, nnx.State): + restore_target = abstract_unboxed_pre_state.to_pure_dict() + + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) + checkpoint_args = ocp.args.PyTreeRestore(item=restore_target, restore_args=restore_args) match (checkpoint_manager, dataset_type, data_iterator): # Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager @@ -636,9 +642,14 @@ def map_to_pspec(data): return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None) if load_parameters_from_path != "": + if isinstance(abstract_unboxed_pre_state, nnx.State): + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) + else: + params = abstract_unboxed_pre_state.params + restored_params = load_params_from_path( load_parameters_from_path, - abstract_unboxed_pre_state.params, + params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, @@ -730,7 +741,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Determine the effective step for saving a checkpoint. # If 'step' is not provided, this call is for a potential final checkpoint # and use the last completed step from the state. 
- actual_step = (int(state.step) - 1) if step is None else int(step) + if step is not None: + actual_step = int(step) + else: + if config.pure_nnx: + actual_step = int(state.optimizer.step) - 1 + else: + # Linen TrainState has .step attribute + actual_step = int(state.step) - 1 + + if config.pure_nnx: + # Convert nnx.State to dict. + state = state.to_pure_dict() # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. # This occurs if this function was called: diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 06063c5fcf..3aac1ca358 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -544,9 +544,13 @@ logical_axis_rules: [ ['paged_kv_head_dim_size', []], ['dense_layers', []], ['moe_layers', []], + ['layers_outside_pipeline', []], + ['layers_per_stage', []], ['engram_dim', ['tensor']], ['mhc', []], ['diloco', 'diloco'], + ['num_activations', []], + ['circular_repeats', []], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] @@ -1126,8 +1130,9 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: false -pure_nnx_decoder: false +enable_nnx: True +pure_nnx_decoder: True +pure_nnx: True ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index 8209dece2d..2e37ead72b 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -72,4 +72,32 @@ logical_axis_rules: [ ['exp_with_fsdp', 'fsdp'], 
['paged_kv_heads', ['tensor']], ['engram_dim', ['tensor']], + # Axes unsharded: sequence/context/tensor_transpose/autoregressive do not exist in this mesh + ['activation_attn_length_no_exp', []], + ['activation_length_no_exp', []], + ['activation_norm_length', []], + ['activation_q_length_no_exp', []], + ['prefill_activation_length', []], + ['prefill_activation_norm_length', []], + ['activation_kv_length', []], + ['decode_length', []], + ['embed_tensor_transpose', []], + ['q_lora_up_proj', []], + ['kv_lora_up_proj', []], + ['kv', []], + ['qkv', []], + ['kv_head_dim', []], + ['cache_batch_prefill', []], + ['cache_batch', []], + ['cache_heads_none', []], + ['cache_kv', []], + ['cache_sequence', []], + ['num_pages', []], + ['tokens_per_page', []], + ['paged_kv_head_dim_size', []], + ['dense_layers', []], + ['moe_layers', []], + ['num_activations', []], + ['mhc', []], + ['diloco', []], ] diff --git a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml index c8a28c5b24..72503feb29 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pure-fsdp.yml @@ -34,4 +34,57 @@ logical_axis_rules: [ ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['exp_with_fsdp', 'fsdp'], + # All other axes are unsharded (tensor/sequence/expert axes do not exist in pure-fsdp) + ['activation_heads', []], + ['activation_kv_heads', []], + ['activation_length', []], + ['activation_attn_length', []], + ['activation_attn_length_no_exp', []], + ['activation_length_no_exp', []], + ['activation_norm_length', []], + ['activation_q_length', []], + ['activation_q_length_no_exp', []], + ['prefill_activation_length', []], + ['prefill_activation_norm_length', []], + ['activation_kv_length', []], + ['activation_attn_embed', []], + ['activation_embed', []], + ['activation_mlp', []], + ['activation_kv', []], + ['activation_kv_head_dim', []], + ['activation_vocab', []], + ['activation_stage', []], + 
['activation_exp', []], + ['decode_length', []], + ['mlp', []], + ['mlp_no_fsdp', []], + ['vocab', []], + ['heads', []], + ['q_heads', []], + ['kv_heads', []], + ['embed_tensor_transpose', []], + ['q_lora_up_proj', []], + ['kv_lora_up_proj', []], + ['norm', []], + ['layers', []], + ['qkv', []], + ['kv', []], + ['kv_head_dim', []], + ['cache_batch_prefill', []], + ['cache_batch', []], + ['cache_heads_none', []], + ['cache_heads', []], + ['cache_kv', []], + ['cache_sequence', []], + ['exp', []], + ['paged_kv_heads', []], + ['num_pages', []], + ['tokens_per_page', []], + ['paged_kv_head_dim_size', []], + ['dense_layers', []], + ['moe_layers', []], + ['num_activations', []], + ['engram_dim', []], + ['mhc', []], + ['diloco', []], ] diff --git a/src/maxtext/configs/decoupled_base_test.yml b/src/maxtext/configs/decoupled_base_test.yml index 07fcaea678..7d6389738e 100644 --- a/src/maxtext/configs/decoupled_base_test.yml +++ b/src/maxtext/configs/decoupled_base_test.yml @@ -1,6 +1,7 @@ # Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml. -# Inherit all model defaults (PyDantic already does this) but override any cloud-coupled paths and disable -# optional cloud features. +# Inherits from base.yml so that logical_axis_rules, mesh_axes, NNX flags, and all other +# model defaults are kept in sync. Overrides only cloud-coupled paths and optional cloud features. +base_config: base.yml # Output goes to a local relative directory so tests do not require GCS. base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs @@ -34,34 +35,9 @@ attention: "dot_product" dump_hlo: false jax_cache_dir: "" -# Neutral parallelism (single device) for local tests. 
-ici_data_parallelism: 1 -ici_tensor_parallelism: 1 -ici_pipeline_parallelism: 1 -ici_expert_parallelism: 1 -ici_sequence_parallelism: 1 -ici_context_parallelism: 1 -ici_tensor_transpose_parallelism: 1 -ici_tensor_sequence_parallelism: 1 -ici_autoregressive_parallelism: 1 -ici_fsdp_parallelism: 1 -ici_fsdp_transpose_parallelism: 1 # Allow higher unsharded parameter percentage for small device count sharding_tolerance: 0.3 -# DCN dimensions to 1 (no multi-slice expectation locally). -dcn_data_parallelism: 1 -dcn_tensor_parallelism: 1 -dcn_pipeline_parallelism: 1 -dcn_expert_parallelism: 1 -dcn_sequence_parallelism: 1 -dcn_context_parallelism: 1 -dcn_tensor_transpose_parallelism: 1 -dcn_tensor_sequence_parallelism: 1 -dcn_autoregressive_parallelism: 1 -dcn_fsdp_parallelism: 1 -dcn_fsdp_transpose_parallelism: 1 - # Config logging off unless a test overrides. log_config: false diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 3f56bf7406..846098b2d4 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -822,6 +822,7 @@ class HardwareAndMesh(BaseModel): optimize_mesh_for_tpu_v6e: bool = Field(False, description="Apply transformations to the mesh for TPU v6e.") shardy: bool = Field(True, description="Whether to use shardy XLA backend.") pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.") + pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.") class LayoutAndSharding(BaseModel): diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 100434ef74..28eef21cb0 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -546,23 +546,43 @@ def setup_train_loop( max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - 
model = mt.from_config(config, devices=training_devices) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - inference_model = mt.from_config(config_inference, devices=inference_devices) + if config_inference.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng, checkpoint_manager, learning_rate_schedule, tx = train_utils.create_training_tools(config, model, mesh) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator = grpo_input_pipeline.create_data_iterator(config_inference, inference_mesh) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + data_iterator, config, mesh, checkpoint_manager, init_state_fn ) # create inference_state_mesh_shardings from inference_mesh + if config_inference.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( - inference_model, tx, config_inference, init_rng, inference_mesh, is_training=False + config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage @@ -697,7 +717,7 @@ def train_loop(config, config_inference, recorder, state=None): data_buffer = [] data_buffer_lock = threading.Lock() - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) inference_prof = profiler.Profiler(config_inference, offset_step=start_step) data_loader = DataLoader(config_inference, inference_mesh, data_iterator, recorder) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 02a2f392c2..23cd2387db 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -113,7 +113,10 @@ def __init__(self, config: Any, devices: Any | None = None): # Model and Optimizer definition quant = quantizations.configure_quantization(config) - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -229,17 +232,25 @@ def load_params(self, 
*args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( - self.model, None, self.config, rng, self._mesh, False + self.config, self._mesh, init_state_fn, False ) # reshard given params based on shardings from config in MaxEngine params = jax.device_put(params, state_mesh_shardings.params) state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - state, self.state_mesh_annotations = maxtext_utils.setup_decode_state( - self.model, self.config, rng1, self._mesh, None - ) + if self.config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index e699b4ac4e..ecfcc19b77 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -549,14 +549,14 @@ def __init__( elif self.is_qwen3_next: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/embeddings.py b/src/maxtext/layers/embeddings.py index 652718e0eb..d61115ecc2 100644 --- a/src/maxtext/layers/embeddings.py +++ b/src/maxtext/layers/embeddings.py @@ -151,10 +151,11 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array: if not jnp.issubdtype(inputs.dtype, jnp.integer): raise ValueError("Input type must be an integer or unsigned integer.") - embedding = jnp.asarray( - _maybe_move_embedding_to_device(self.embedding.value, self.config), - self.dtype, - ) + embedding_val = _maybe_move_embedding_to_device(self.embedding.value, self.config) + if isinstance(embedding_val, jax.ShapeDtypeStruct): + embedding = embedding_val + else: + embedding = 
jnp.asarray(embedding_val, self.dtype) output_axis_names = ( ( diff --git a/src/maxtext/layers/linears.py b/src/maxtext/layers/linears.py index 4af9c5c530..14ce36b6b3 100644 --- a/src/maxtext/layers/linears.py +++ b/src/maxtext/layers/linears.py @@ -220,32 +220,48 @@ def __call__(self, inputs: Array, _initializing: bool = False, out_sharding: Nam kernel_shape = self.in_features_shape + self.out_features_shape kernel = jnp.zeros(kernel_shape, dtype=self.dtype) else: - kernel = self.kernel[...] - # Move logit_dense kernel to device if parameter offloading is enabled - if self.parameter_memory_host_offload: - max_logging.log("linear.py: Moving parameter logits_dense kernel to device") - kernel = jax.device_put(kernel, max_utils.device_space()) - kernel = jnp.asarray(kernel, self.dtype) + kernel_val = self.kernel.value + if kernel_val is not None: + if isinstance(kernel_val, jax.ShapeDtypeStruct): + # Bypass concrete indexing for abstract tracers + kernel = kernel_val + else: + kernel = self.kernel[...] 
+ # Move logit_dense kernel to device if parameter offloading is enabled + if self.parameter_memory_host_offload: + max_logging.log("linear.py: Moving parameter logits_dense kernel to device") + kernel = jax.device_put(kernel, max_utils.device_space()) + kernel = jnp.asarray(kernel, self.dtype) + else: + kernel = None # out_sharding should be None for auto mesh axis if self.shard_mode != ShardMode.EXPLICIT: out_sharding = None - contract_ind = tuple(range(0, len(self.axis))) - output = _compute_dot_general_nnx( - inputs, - kernel, - norm_axis, - contract_ind, - self.matmul_precision, - self.quant_dot_general, - _initializing, - out_sharding, - ) + if kernel is not None: + contract_ind = tuple(range(0, len(self.axis))) + output = _compute_dot_general_nnx( + inputs, + kernel, + norm_axis, + contract_ind, + self.matmul_precision, + self.quant_dot_general, + _initializing, + out_sharding, + ) + + if self.bias is not None: + bias_val = self.bias.value + if bias_val is not None: + bias = jnp.asarray(self.bias[...], self.dtype) + output += bias + else: + # If kernel is missing (e.g. masked in pipeline), return zeros. + out_shape = inputs.shape[: -len(self.axis)] + self.out_features_shape + output = jnp.zeros(out_shape, dtype=self.dtype) - if self.bias is not None: - bias = jnp.asarray(self.bias[...], self.dtype) - output += bias return output diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 314c450b03..1002ab957a 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -273,25 +273,35 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. kernel_shape = self.in_features_shape + self.out_features_shape kernel = jnp.zeros(kernel_shape, dtype=self.dtype) else: - kernel = self.kernel[...] - kernel = jnp.asarray(kernel, self.dtype) + kernel_val = self.kernel.value + if kernel_val is not None: + kernel = self.kernel[...] 
+ kernel = jnp.asarray(kernel, self.dtype) + else: + kernel = None + + if kernel is not None: + contract_ind = tuple(range(0, len(norm_axis))) + output_sharding = ( + create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None)) + if self.shard_mode == ShardMode.EXPLICIT + else None + ) + output = linears._compute_dot_general_nnx( + inputs, + kernel, + norm_axis, + contract_ind, + self.matmul_precision, + self.quant_dot_general, + _initializing, + out_sharding=output_sharding, + ) + else: + # If kernel is missing (e.g. masked in pipeline), return zeros. + out_shape = inputs.shape[:-1] + self.out_features_shape + output = jnp.zeros(out_shape, dtype=self.dtype) - contract_ind = tuple(range(0, len(norm_axis))) - output_sharding = ( - create_sharding(self.mesh, ("activation_batch_no_exp_moe", "activation_length_no_exp_moe", None)) - if self.shard_mode == ShardMode.EXPLICIT - else None - ) - output = linears._compute_dot_general_nnx( - inputs, - kernel, - norm_axis, - contract_ind, - self.matmul_precision, - self.quant_dot_general, - _initializing, - out_sharding=output_sharding, - ) pre_bias_logits = None if self.score_func: @@ -300,8 +310,10 @@ def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax. 
pre_bias_logits = output if self.use_bias: - bias = jnp.asarray(self.bias[...], self.dtype) - output += bias + bias_val = self.bias.value + if bias_val is not None: + bias = jnp.asarray(self.bias[...], self.dtype) + output += bias return output, pre_bias_logits @@ -2026,40 +2038,47 @@ def __call__( routing_inputs = inputs if gate_inputs is None else gate_inputs.astype(gate_dtype) gate_logits, pre_bias_logits = self.gate(routing_inputs) - w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) - w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) - wo_kernel = jnp.asarray(self.wo[...], self.dtype) + if self.wi_0.value is not None: + w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) + w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) + wo_kernel = jnp.asarray(self.wo[...], self.dtype) - if self.per_expert_scale is not None: - wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] - - if cfg.mlp_bias: - w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype) - w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype) - wo_bias = jnp.asarray(self.wo_bias[...], self.dtype) - else: - w0_bias, w1_bias, wo_bias = None, None, None + if self.per_expert_scale is not None: + wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] - if cfg.sparse_matmul: - if quantizations.in_serve_mode(self.quant): - w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight( - inputs, - gate_logits, - pre_bias_logits, - w0_kernel, - w1_kernel, - wo_kernel, - w0_bias, - w1_bias, - wo_bias, + if cfg.mlp_bias: + w0_bias = jnp.asarray(self.wi_0_bias[...], self.dtype) + w1_bias = jnp.asarray(self.wi_1_bias[...], self.dtype) + wo_bias = jnp.asarray(self.wo_bias[...], self.dtype) + else: + w0_bias, w1_bias, wo_bias = None, None, None + + if cfg.sparse_matmul: + if quantizations.in_serve_mode(self.quant): + w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight( + inputs, + gate_logits, + pre_bias_logits, + w0_kernel, 
+ w1_kernel, + wo_kernel, + w0_bias, + w1_bias, + wo_bias, + ) + output, lb_loss, bias_updates = self.sparse_matmul( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias + ) + else: + output, lb_loss, bias_updates = self.dense_matmul( + inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias ) - output, lb_loss, bias_updates = self.sparse_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias - ) else: - output, lb_loss, bias_updates = self.dense_matmul( - inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias - ) + # If kernels are missing (e.g. masked in pipeline), return zeros. + output = jnp.zeros_like(inputs) + lb_loss = None + bias_updates = None + return output, lb_loss, bias_updates diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c96ec08c8d..a0c41d2b46 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -23,6 +23,7 @@ import jax import jax.numpy as jnp + from flax import linen as nn from flax import nnx from flax.nnx import wrappers as nnx_wrappers @@ -63,6 +64,8 @@ from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding from maxtext.utils.sharding import create_sharding +from maxtext.layers.pipeline import create_nnx_pipeline + # ------------------------------------------------------------------------------ # The network: Decoder Definitions @@ -220,7 +223,7 @@ def deepstack_process(hidden_states, bidirectional_mask, visual_embeds): """Process deepstack visual embeddings by adding them to hidden states at visual token positions. 
Args: - hidden_states: [batch, seq_len, hidden_dim] decoder hidden states + hidden_states:[batch, seq_len, hidden_dim] decoder hidden states bidirectional_mask: [batch, seq_len] boolean mask marking visual token positions visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features from encoder layer @@ -235,12 +238,91 @@ def deepstack_process(hidden_states, bidirectional_mask, visual_embeds): # Gather visual tokens: for each position, get the corresponding visual token batch_idx = jnp.arange(hidden_states.shape[0])[:, jnp.newaxis] # [batch, 1] visual_embeds_scattered = visual_embeds[batch_idx, visual_token_idx, :] # [batch, seq_len, hidden] - # Only add where mask is True: hidden_states += visual_embeds * mask hidden_states = hidden_states + visual_embeds_scattered * mask_expanded return hidden_states +class NNXSequentialPipelineStage(nnx.Module): + """Sequential unscanned series of decoder layers formatted for a single pipeline stage.""" + + def __init__( + self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: nnx.Rngs + ): + self.config = config + self.scan_layers = config.scan_layers + self.num_layers = num_layers + # Dynamically assign layers with explicit string names to ensure correct PyTree paths (layers_0) + for i in range(num_layers): + layer = layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rngs) + setattr(self, f"layers_{i}", layer) + + def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs): + for i in range(self.num_layers): + layer = getattr(self, f"layers_{i}") + out = layer(inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs) + inputs = out[0] if isinstance(out, tuple) else out + if self.scan_layers: + return inputs, None + return inputs + + +class NNXScannedPipelineStage(nnx.Module): + """Scanned block of decoder layers formatted for a single pipeline stage.""" + + def __init__( + 
self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: nnx.Rngs + ): + self.config = config + + def create_layer_fn(rng): + return layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rng) + + # Workaround for Deepseek MTP test failure. + # TODO: Handle this properly. + try: + forked_rngs = rngs.fork(split=num_layers) + except: # pylint: disable=bare-except + forked_rngs = rngs + + out_axes = nnx.StateAxes({nnx.Param: config.param_scan_axis, ...: 0}) + self.scanned_layers = nnx.vmap( + create_layer_fn, + in_axes=0, + out_axes=out_axes, + axis_name="layers_per_stage", + transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}, + )(forked_rngs) + + def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs): + graphdef, params, state = nnx.split(self.scanned_layers, nnx.Param, ...) + + scan_axis = self.config.param_scan_axis + if scan_axis != 0: + params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) + + def layer_fn(carry, scanned_vars): + current_params, current_state = scanned_vars + layer = nnx.merge(graphdef, current_params, current_state) + layer_out = layer(carry, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs) + new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out + nnx.pop(layer, nnx.Intermediate) + return new_carry, nnx.state(layer) + + final_carry, scanned_state = jax.lax.scan(layer_fn, inputs, (params, state)) + + if scan_axis != 0: + scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) 
+ scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) + scanned_state = nnx.State.merge(scanned_params, scanned_other) + + self.scanned_layers = nnx.merge(graphdef, scanned_state) + + if self.config.scan_layers: + return final_carry, None + return final_carry + + class NNXDecoder(nnx.Module): """A stack of decoder layers as a part of an encoder-decoder architecture, using NNX.""" @@ -301,78 +383,144 @@ def __init__( self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3 - if self.config.scan_layers: - if self.is_deepseek: - assert len(decoder_block_classes) == 2 - dense_cls, moe_cls = decoder_block_classes - - num_dense = config.first_num_dense_layers - self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs) - - num_moe = config.num_decoder_layers - config.first_num_dense_layers + if config.using_pipeline_parallelism: - self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs) - elif self.is_gemma3: - attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) - scan_length = config.num_decoder_layers // attention_pattern_length - num_remaining_layers = config.num_decoder_layers % attention_pattern_length - layer_kwargs = {"num_of_layers": attention_pattern_length} + def stage_factory(rngs): + return self._get_pipeline_stage_module(decoder_block_classes, rngs) - rem_layer_kwargs = {"num_of_layers": num_remaining_layers} - - RemattedGemma3Block = gemma3.Gemma3ScannableBlock - - if scan_length > 0: - self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs) - self.layers_remainder = RemattedGemma3Block( - config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs - ) # pytype: disable=wrong-keyword-args - else: - layer_cls = decoder_block_classes[0] - num_layers = 
int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) - layer_kwargs = {} - if config.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "nope_layer_interval": self.config.nope_layer_interval, - "interleave_moe_layer_step": self.config.interleave_moe_layer_step, - } - - self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs) - else: - self.layers = nnx.List([]) + self.pipeline_module = create_nnx_pipeline( + config=config, + stage_factory=stage_factory, + mesh=mesh, + remat_policy=self.get_remat_policy(), + rngs=rngs, + ) if self.is_deepseek: + assert len(decoder_block_classes) == 2 dense_cls, moe_cls = decoder_block_classes - for i in range(config.first_num_dense_layers): - self._create_and_register_layer(dense_cls, rngs, "dense_layer", i) - for i in range(config.num_decoder_layers - config.first_num_dense_layers): - self._create_and_register_layer(moe_cls, rngs, "moe_layer", i) + if config.scan_layers: + self.dense_layers = self._create_scanned_layers( + dense_cls, length=config.first_num_dense_layers, metadata_axis_name="dense_layers", rngs=rngs + ) + num_moe_outside = (config.num_decoder_layers - config.first_num_dense_layers) - config.pipeline_parallel_layers + if num_moe_outside > 0: + self.moe_layers_outside_pipeline = self._create_scanned_layers( + moe_cls, length=num_moe_outside, metadata_axis_name="moe_layers", rngs=rngs + ) + else: + self.num_dense_layers = config.first_num_dense_layers + for i in range(self.num_dense_layers): + self._create_and_register_layer(dense_cls, rngs, "dense_layers", i) + + self.num_moe_outside_pipeline = ( + config.num_decoder_layers - config.first_num_dense_layers + ) - config.pipeline_parallel_layers + if self.num_moe_outside_pipeline > 0: + for i in range(self.num_moe_outside_pipeline): + self._create_and_register_layer(moe_cls, rngs, "moe_layers_outside_pipeline", i) + else: - layer_cls = decoder_block_classes[0] + remaining_layers = 
config.num_decoder_layers - config.pipeline_parallel_layers + if remaining_layers > 0: + base_cls = decoder_block_classes[0] + if config.scan_layers: + self.layers_outside_pipeline = self._create_scanned_layers( + base_cls, length=remaining_layers, metadata_axis_name="layers_outside_pipeline", rngs=rngs + ) + self.num_layers_outside_pipeline = remaining_layers + else: + self.num_layers_outside_pipeline = remaining_layers + for i in range(self.num_layers_outside_pipeline): + self._create_and_register_layer(base_cls, rngs, "layers_outside_pipeline", i) - for lyr in range(config.num_decoder_layers): + else: + # Setup for Standard Non-Pipeline Execution + if self.config.scan_layers: + if self.is_deepseek: + assert len(decoder_block_classes) == 2 + dense_cls, moe_cls = decoder_block_classes + self.dense_layers = self._create_scanned_layers( + dense_cls, length=config.first_num_dense_layers, metadata_axis_name="dense_layers", rngs=rngs + ) + num_moe = config.num_decoder_layers - config.first_num_dense_layers + self.moe_layers = self._create_scanned_layers( + moe_cls, length=num_moe, metadata_axis_name="moe_layers", rngs=rngs + ) + elif self.is_gemma3: + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = config.num_decoder_layers // attention_pattern_length + num_remaining_layers = config.num_decoder_layers % attention_pattern_length + layer_kwargs = {"num_of_layers": attention_pattern_length} + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + RemattedGemma3Block = gemma3.Gemma3ScannableBlock + if scan_length > 0: + self.layers = self._create_scanned_layers( + RemattedGemma3Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + self.layers_remainder = RemattedGemma3Block( + config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs + ) + else: + layer_cls = decoder_block_classes[0] + num_layers = int(config.num_decoder_layers / 
config.inhomogeneous_layer_cycle_interval) layer_kwargs = {} - if config.decoder_block == DecoderBlockType.GEMMA3: - layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} - elif config.decoder_block == DecoderBlockType.LLAMA4: + if config.decoder_block == DecoderBlockType.LLAMA4: layer_kwargs = { - "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), - "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), + "nope_layer_interval": self.config.nope_layer_interval, + "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: - layer_kwargs = {"layer_idx": lyr} - elif config.decoder_block == DecoderBlockType.GPT_OSS: - layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} - elif config.decoder_block == DecoderBlockType.OLMO3: - layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + self.layers = self._create_scanned_layers( + layer_cls, length=num_layers, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + else: + if self.is_deepseek: + dense_cls, moe_cls = decoder_block_classes + self.num_dense_layers = config.first_num_dense_layers + for i in range(self.num_dense_layers): + self._create_and_register_layer(dense_cls, rngs, "dense_layers", i) + self.num_moe_layers = config.num_decoder_layers - config.first_num_dense_layers + for i in range(self.num_moe_layers): + self._create_and_register_layer(moe_cls, rngs, "moe_layers", i) + else: + layer_cls = decoder_block_classes[0] + self.num_layers = config.num_decoder_layers + for lyr in range(self.num_layers): + layer_kwargs = {} + if config.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "is_nope_layer": llama4.determine_is_nope_layer(lyr, 
self.config.nope_layer_interval), + "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), + } + elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: + layer_kwargs = {"layer_idx": lyr} + elif config.decoder_block == DecoderBlockType.GPT_OSS: + layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.OLMO3: + layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) + + def _get_pipeline_stage_module(self, decoder_blocks, rngs): + """Retrieves the wrapper module formatted for single pipeline stage execution.""" + cfg = self.config + base_stage_cls = decoder_blocks[1] if self.is_deepseek else decoder_blocks[0] - self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) + if cfg.num_layers_per_pipeline_stage == 1: + return self._create_single_layer(base_stage_cls, rngs) + elif cfg.scan_layers_per_stage or cfg.scan_layers: + return NNXScannedPipelineStage( + base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs + ) + return NNXSequentialPipelineStage( + base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs + ) def _create_and_register_layer(self, layer_cls, rngs, base_name, i, **layer_kwargs): attr_name = f"{base_name}_{i}" layer = self._create_single_layer(layer_cls, rngs, **layer_kwargs) setattr(self, attr_name, layer) - self.layers.append(layer) def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): """Helper to create a single layer (Linen or NNX).""" @@ -386,38 +534,37 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): ) return nnx_wrappers.ToNNX(layer_linen, rngs=rngs) - def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs): + def _create_scanned_layers( + 
self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs + ): """Creates a VMapped stack of layers, forcing parameter init for Compact modules.""" + if length == 0: + return nnx.List([]) def create_layer_fn(rng): - layer = decoder_layer_class( + return decoder_layer_class( config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs ) - return layer - # Workaround for Deepseek MTP test failure. # TODO: Handle this properly. try: forked_rngs = rngs.fork(split=length) - except: # pylint: disable=bare-except - pass + forked_rngs = rngs out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0}) layers_vmapped = nnx.vmap( create_layer_fn, in_axes=0, out_axes=out_axes, - axis_name="layers", - transform_metadata={nnx.PARTITION_NAME: "layers"}, + axis_name=metadata_axis_name, + transform_metadata={nnx.PARTITION_NAME: metadata_axis_name}, )(forked_rngs) - return layers_vmapped def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs): """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block.""" - graphdef, state = nnx.split(layer) def pure_layer_fn(state_in, y_in): @@ -425,14 +572,17 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + if not self._uses_linen_fp8_ops(): + pure_layer_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = pure_layer_fn(state, y) nnx.update(layer, new_state) - return out def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs): """Runs the layer stack using nnx.scan.""" + if length == 0: + _, empty_state = nnx.split(layers) + return x_in, empty_state policy = self.get_remat_policy() prevent_cse = 
maxtext_utils.should_prevent_cse_in_remat(self.config) graphdef, params, state = nnx.split( @@ -442,47 +592,232 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs) scan_axis = self.config.param_scan_axis if scan_axis != 0: # Move scan_axis to 0 so scan can iterate over it - params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) + def move_axis_abstract_safe(x): + if isinstance(x, jax.ShapeDtypeStruct): + # Manually calculate new shape for abstract tracers + new_shape = list(x.shape) + ax = scan_axis if scan_axis >= 0 else len(new_shape) + scan_axis + val = new_shape.pop(ax) + new_shape.insert(0, val) + return jax.ShapeDtypeStruct(tuple(new_shape), x.dtype) + return jnp.moveaxis(x, scan_axis, 0) + + params = jax.tree.map(move_axis_abstract_safe, params) layer_cls = layers.__class__ sig = inspect.signature(layer_cls.__call__) valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} - layer_cls = layers.__class__ # Access the underlying class - sig = inspect.signature(layer_cls.__call__) - # Filter kwargs to only include keys that exist in the layer's signature - valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} - def layer_fn(carry, scanned_vars): - # Unpack the sliced variables for THIS layer current_params, current_state = scanned_vars - if self.config.parameter_memory_host_offload: current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params) - # Merge using the SLICED state layer = nnx.merge(graphdef, current_params, current_state) # Run the layer (Filter kwargs if using the solution from previous turn) layer_out = layer(carry, *args, **valid_kwargs) - new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out - - # Extract the updated state to return it - # _, new_current_state = nnx.split(layer, nnx.Param, ...) 
+ nnx.pop(layer, nnx.Intermediate) new_current_state = nnx.state(layer) return new_carry, new_current_state - layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + if self._uses_linen_fp8_ops(): + # jax.lax.scan is incompatible with Linen fp8 ops: put_variable in setup() stores + # scan-level tracers as Python attributes on the Linen module, causing a tracer leak + # across the scan boundary. Fall back to a Python loop instead. + x = x_in + for i in range(length): + params_i = jax.tree.map(lambda p, _i=i: p[_i], params) + state_i = jax.tree.map(lambda s, _i=i: s[_i], state) + layer = nnx.merge(graphdef, params_i, state_i) + layer_out = layer(x, *args, **valid_kwargs) + x = layer_out[0] if isinstance(layer_out, tuple) else layer_out + nnx.pop(layer, nnx.Intermediate) + if scan_axis != 0: + params = jax.tree.map(lambda p: jnp.moveaxis(p, 0, scan_axis), params) + return x, nnx.State.merge(params, state) + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) if scan_axis != 0: - scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) - scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) - scanned_state = nnx.State.merge(scanned_params, scanned_other) + # Only move the axis back on the params, NOT the mutables! + def move_axis_back_abstract_safe(x): + if isinstance(x, jax.ShapeDtypeStruct): + new_shape = list(x.shape) + val = new_shape.pop(0) + ax = scan_axis if scan_axis >= 0 else len(new_shape) + 1 + scan_axis + new_shape.insert(ax, val) + return jax.ShapeDtypeStruct(tuple(new_shape), x.dtype) + return jnp.moveaxis(x, 0, scan_axis) + + params = jax.tree.map(move_axis_back_abstract_safe, params) + + final_state = nnx.State.merge(params, scanned_state) + # Skip direct mutation of 'layers' during compilation to avoid TraceContextError. + # The caller will handle the final state update. 
+ return final_carry, final_state + + def _apply_interleaved_scanned_layers( + self, layers, y, layer_args, layer_kwargs, start_idx, end_idx, engram_indices, decoder_input_tokens + ): + """Applies a mix of scanned standard layers and unscanned Engram layers efficiently using native NNX state slicing.""" + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) + graphdef, params, mutables = nnx.split(layers, nnx.Param, ...) + + scan_axis = self.config.param_scan_axis + if scan_axis != 0: + max_logging.log(f"nnx_decoders: Moving param scan_axis from {scan_axis} to 0 for interleaved scan.") + params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) + + def get_chunk(pytree, start, end): + return jax.tree.map(lambda x: x[start:end], pytree) + + updated_mutables_chunks = [] + current_idx = start_idx + + while current_idx < end_idx: + if current_idx in engram_indices: + # Single engram layer execution + eng_params = get_chunk(params, current_idx, current_idx + 1) + eng_mutables = get_chunk(mutables, current_idx, current_idx + 1) + + # Squeeze the vmapped 'layers' dimension out for isolated execution + eng_params = jax.tree.map(lambda x: jnp.squeeze(x, axis=0), eng_params) + eng_mutables = jax.tree.map(lambda x: jnp.squeeze(x, axis=0), eng_mutables) + + if self.config.parameter_memory_host_offload: + eng_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), eng_params) + + layer = nnx.merge(graphdef, eng_params, eng_mutables) + kwargs_with_tokens = {**layer_kwargs, "decoder_input_tokens": decoder_input_tokens, "layer_idx": current_idx} - return final_carry, nnx.merge(graphdef, scanned_state) + sig = inspect.signature(layer.__call__) + valid_kwargs = {k: v for k, v in kwargs_with_tokens.items() if k in sig.parameters or "kwargs" in sig.parameters} + + layer_out = layer(y, *layer_args, **valid_kwargs) + y = layer_out[0] if isinstance(layer_out, tuple) else layer_out + + nnx.pop(layer, 
nnx.Intermediate) + _, _, new_eng_mutables = nnx.split(layer, nnx.Param, ...) + new_eng_mutables = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), new_eng_mutables) + updated_mutables_chunks.append(new_eng_mutables) + current_idx += 1 + else: + # Scan a continuous chunk of non-engram layers + next_engrams = [l for l in engram_indices if l > current_idx] + if next_engrams: + min_next_engram = min(next_engrams) + next_boundary = min(end_idx, min_next_engram) + else: + next_boundary = end_idx + + chunk_params = get_chunk(params, current_idx, next_boundary) + chunk_mutables = get_chunk(mutables, current_idx, next_boundary) + + sig = inspect.signature(layers.__call__) + valid_kwargs = {k: v for k, v in layer_kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + + def layer_fn(carry, scanned_vars): + curr_p, curr_m = scanned_vars + if self.config.parameter_memory_host_offload: + curr_p = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), curr_p) + l = nnx.merge(graphdef, curr_p, curr_m) + l_out = l(carry, *layer_args, **valid_kwargs) + n_carry = l_out[0] if isinstance(l_out, tuple) else l_out + nnx.pop(l, nnx.Intermediate) + _, _, n_mut = nnx.split(l, nnx.Param, ...) + return n_carry, n_mut + + if not self._uses_linen_fp8_ops(): + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + y, new_chunk_mutables = jax.lax.scan(layer_fn, y, (chunk_params, chunk_mutables)) + updated_mutables_chunks.append(new_chunk_mutables) + current_idx = next_boundary + + if updated_mutables_chunks: + final_mutables = jax.tree.map(lambda *chunks: jnp.concatenate(chunks, axis=0), *updated_mutables_chunks) + else: + final_mutables = mutables + + if scan_axis != 0: + max_logging.log(f"nnx_decoders: Moving param scan_axis 0 back to {scan_axis} for interleaved scan.") + # Only move the axis back on params! 
+ params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) + + final_state = nnx.State.merge(params, final_mutables) + nnx.update(layers, final_state) + return y, layers + + def _run_unscanned_layers_loop( + self, + base_name, + num_layers, + y, + layer_args, + layer_kwargs, + kv_caches=None, + deepstack_visual_embeds=None, + bidirectional_mask=None, + layer_idx_offset=0, + decoder_input_tokens=None, + ): + """DRY Helper for looping unscanned lists of layers while correctly handling remat, state, engrams, and KV cache.""" + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) + + def pure_layer_fn(graphdef, state_in, y_in, kv_in, dynamic_kwargs): + merged_layer = nnx.merge(graphdef, state_in) + out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **dynamic_kwargs) + return out_y, out_kv, nnx.state(merged_layer) + + checkpointed_fn = ( + pure_layer_fn + if self._uses_linen_fp8_ops() + else jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + ) + + for lyr in range(num_layers): + attr_name = f"{base_name}_{lyr}" + layer = getattr(self, attr_name) + graphdef, state = nnx.split(layer) + global_lyr = layer_idx_offset + lyr + + # Prepare dynamic KV Cache unwrapping + kv_cache = None + if kv_caches is not None and self.config.decoder_block != DecoderBlockType.QWEN3_NEXT: + kv_cache = kv_caches[global_lyr] + elif kv_caches is not None and self.config.decoder_block == DecoderBlockType.QWEN3_NEXT: + if (global_lyr + 1) % self.config.inhomogeneous_layer_cycle_interval == 0: + kv_cache = (kv_caches["key_cache"][global_lyr], kv_caches["value_cache"][global_lyr]) + + # Prepare dynamic Kwargs (Engrams, Layer ID) + current_kwargs = dict(layer_kwargs) + if self.config.engram_layers: + current_kwargs["decoder_input_tokens"] = decoder_input_tokens + if self.config.decoder_block == DecoderBlockType.DEEPSEEK: + current_kwargs["layer_idx"] = global_lyr + + y, returned_cache, new_state = 
checkpointed_fn(graphdef, state, y, kv_cache, current_kwargs) + # Re-merge the state back to the explicit attribute to prevent cross-boundary TraceContextErrors + setattr(self, attr_name, nnx.merge(graphdef, new_state)) + + # Write updated KV Cache back properly + if kv_caches is not None and returned_cache is not None: + if self.config.decoder_block != DecoderBlockType.QWEN3_NEXT: + kv_caches[global_lyr] = returned_cache + elif (global_lyr + 1) % self.config.inhomogeneous_layer_cycle_interval == 0: + kv_caches["key_cache"][global_lyr] = returned_cache[0] + kv_caches["value_cache"][global_lyr] = returned_cache[1] + + if deepstack_visual_embeds is not None and global_lyr < len(deepstack_visual_embeds): + visual_embeds = deepstack_visual_embeds[global_lyr] + if bidirectional_mask is not None and visual_embeds is not None: + y = deepstack_process(y, bidirectional_mask, visual_embeds) + + return y def get_decoder_layers(self): """Retrieves decoder layer classes based on config using a dictionary lookup.""" @@ -518,7 +853,6 @@ def get_deepseek(): if cfg.decoder_block not in layer_map: raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}") - return layer_map[cfg.decoder_block] def minimal_policy(self, with_context=False, with_quantization=False): @@ -573,37 +907,18 @@ def get_remat_policy(self): policy = self.minimal_policy(with_context=True, with_quantization=True) elif cfg.remat_policy == "save_dot_with_context_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "context", - "out_proj", + "query_proj", "value_proj", "key_proj", "qkv_proj", "context", "out_proj" ) elif cfg.remat_policy == "save_dot_except_mlpwi": policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwo", + "query_proj", "value_proj", "key_proj", "qkv_proj", "out_proj", "mlpwo" ) elif cfg.remat_policy == 
"save_dot_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", + "query_proj", "value_proj", "key_proj", "qkv_proj", "out_proj" ) elif cfg.remat_policy == "save_qkv_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - ) + policy = jax.checkpoint_policies.save_only_these_names("query_proj", "value_proj", "key_proj", "qkv_proj") elif cfg.remat_policy == "qkv_proj_offloaded": policy = jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=[], @@ -612,7 +927,6 @@ def get_remat_policy(self): offload_dst="pinned_host", ) elif cfg.remat_policy == "minimal_offloaded": - # offload all except context policy = jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=[], names_which_can_be_offloaded=[ @@ -640,11 +954,14 @@ def get_remat_policy(self): policy = jax.checkpoint_policies.save_only_these_names("out_proj") else: assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" - policy = None return policy + def _uses_linen_fp8_ops(self) -> bool: + """Returns True if the quantization mode uses Linen fp8 ops incompatible with jax.checkpoint.""" + return self.config.quantization in ("fp8_gpu", "fp8_nanoo") + def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): - """get normalization layer (return type inherits from nn.Module)""" + """Helper to retrieve the correct normalization layer class based on config, partially applied with common arguments.""" if self.config.decoder_block in ( DecoderBlockType.DEFAULT, DecoderBlockType.LLAMA2, @@ -687,28 +1004,25 @@ def _apply_embedding( audio_embeddings=None, audio_masks=None, ): - """Applies token and positional embeddings to the input tokens.""" + """Applies token embedding, adds positional embedding, and merges multimodal embeddings if provided.""" cfg = self.config - y = 
shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) - # Merge the image embeddings with the text embeddings for multimodal models if image_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in [ + if cfg.model_name in { "gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", - ]: + }: y = mm_utils.merge_mm_embeddings( text_embeddings=y, multimodal_embeddings=image_embeddings, mask=bidirectional_mask, token_masks=image_masks, ) - # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed else: raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") @@ -736,7 +1050,6 @@ def _apply_embedding( def apply_output_head(self, shared_embedding, y, deterministic, model_mode): """Applies final normalization and projects hidden states to logits.""" - cfg = self.config if cfg.shard_mode == ShardMode.EXPLICIT: norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) @@ -779,115 +1092,6 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode): return logits - def _build_linen_params(self, moe_stack: nnx.Module) -> dict: - """ - Bridges NNX to Linen by creating a dictionary that mimics the exact variable - structure expected by `deepseek_batchsplit.fetch_weights`. 
- """ - - return { - "pre_self_attention_layer_norm": { - "scale": moe_stack.pre_self_attention_layer_norm.scale, - }, - "post_self_attention_layer_norm": { - "scale": moe_stack.post_self_attention_layer_norm.scale, - }, - "self_attention": { - "wq_a": {"kernel": moe_stack.self_attention.wq_a.kernel}, - "wq_b": {"kernel": moe_stack.self_attention.wq_b.kernel}, - "q_norm": {"scale": moe_stack.self_attention.q_norm.scale}, - "wkv_a": {"kernel": moe_stack.self_attention.wkv_a.kernel}, - "wkv_b": {"kernel": moe_stack.self_attention.wkv_b.kernel}, - "kv_norm": {"scale": moe_stack.self_attention.kv_norm.scale}, - "out": {"kernel": moe_stack.self_attention.out.kernel}, - }, - "DeepSeekMoeBlock_0": { - "MoeBlock_0": { - "gate": { - "kernel": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel, - "bias": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias, - }, - "wi_0": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wi_0, - "wi_1": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wi_1, - "wo": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wo, - }, - "shared_experts": { - "wi_0": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel}, - "wi_1": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel}, - "wo": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wo.kernel}, - }, - }, - } - - def _find_next_boundary(self, current_idx, end_idx, engram_indices): - """Finds the next index boundary, either the next Engram layer index or the overall end index.""" - next_engrams = [l for l in engram_indices if l > current_idx] - if next_engrams: - return min(end_idx, *next_engrams) - return end_idx - - def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs): - """Applies a single, unscanned Engram layer by dynamically slicing the NNX state.""" - graphdef, state = nnx.split(layer_stack) - - # Slice the parameters for the current index (assuming scan axis is 0) - sliced_state = jax.tree.map(lambda x: x[current_idx], state) - single_layer = 
nnx.merge(graphdef, sliced_state) - - # Run the single layer - out = single_layer( - y, *args, decoder_input_tokens=kwargs.get("decoder_input_tokens"), **kwargs.get("layer_kwargs", {}) - ) - y = out[0] if isinstance(out, tuple) else out - - # Re-merge the updated state back into the specific slice of the stack - new_single_state = nnx.state(single_layer) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0), - state, - new_single_state, - ) - nnx.update(layer_stack, updated_state) - - return y - - def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs): - """Applies a contiguous chunk of layers using scan over a state slice.""" - scan_length = next_boundary - current_idx - if scan_length > 0: - graphdef, state = nnx.split(layer_stack) - - # Slice the chunk state - chunk_state = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), state) - chunk_stack = nnx.merge(graphdef, chunk_state) - - # Apply sequentially - y, chunk_stack = self._apply_layers_sequentially( - chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {}) - ) - - # Update the original stack state - new_chunk_state = nnx.state(chunk_stack) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), state, new_chunk_state - ) - nnx.update(layer_stack, updated_state) - - return y - - def _apply_interleaved_scanned_layers(self, y, layer_stack, start_idx, end_idx, engram_indices, *args, **kwargs): - """Applies a mix of scanned standard layers and unscanned Engram layers.""" - current_idx = start_idx - while current_idx < end_idx: - if current_idx in engram_indices: - y = self._apply_single_engram_layer(y, current_idx, layer_stack, *args, **kwargs) - current_idx += 1 - else: - next_boundary = self._find_next_boundary(current_idx, end_idx, engram_indices) - y = 
self._apply_scanned_chunk(y, current_idx, next_boundary, layer_stack, *args, **kwargs) - current_idx = next_boundary - return y - def __call__( self, shared_embedding: Any, @@ -907,11 +1111,17 @@ def __call__( audio_embeddings: None | jnp.ndarray = None, audio_masks: None | jnp.ndarray = None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, + multimodal_input=None, ): cfg = self.config assert decoder_input_tokens.ndim == 2 # [batch, len] - policy = self.get_remat_policy() + if multimodal_input is not None: + image_embeddings = multimodal_input.image_embeddings + bidirectional_mask = multimodal_input.bidirectional_mask + image_masks = multimodal_input.image_masks + audio_embeddings = multimodal_input.audio_embeddings + audio_masks = multimodal_input.audio_masks # [batch, length] -> [batch, length, emb_dim] y = self._apply_embedding( @@ -940,134 +1150,271 @@ def __call__( if attention_metadata is not None: layer_kwargs["attention_metadata"] = attention_metadata + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK and cfg.scan_layers: + layer_kwargs = {"previous_chunk": previous_chunk, "page_state": page_state, "slot": slot} - if cfg.scan_layers: - if self.is_deepseek: - layer_kwargs = { - "previous_chunk": previous_chunk, - "page_state": page_state, - "slot": slot, - } - - if cfg.engram_layers: - common_kwargs = { - "layer_kwargs": layer_kwargs, - "decoder_input_tokens": decoder_input_tokens, - } - - y = self._apply_interleaved_scanned_layers( - y, self.dense_layers, 0, cfg.first_num_dense_layers, cfg.engram_layers, *layer_args, **common_kwargs - ) - - y = self._apply_interleaved_scanned_layers( - y, - self.moe_layer, - 0, - (cfg.num_decoder_layers - cfg.first_num_dense_layers), - [e - cfg.first_num_dense_layers for e in cfg.engram_layers], - *layer_args, - **common_kwargs, - ) - else: - y, self.dense_layers = self._apply_layers_sequentially( - self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs - ) - - num_moe = 
cfg.num_decoder_layers - cfg.first_num_dense_layers - - if cfg.use_batch_split_schedule: - policy = self.get_remat_policy() - - mock_params = self._build_linen_params(self.moe_layer) + # ------------------------------------------------------------------------- + # Execution Routing (Pipeline vs Direct) + # ------------------------------------------------------------------------- + if cfg.using_pipeline_parallelism: + logical_partition_spec = self.pipeline_module.get_weight_sharding() if cfg.pipeline_fsdp_ag_once else None - y = deepseek_batchsplit.scan_batch_split_layers( - y, - mock_params, - decoder_positions, - decoder_segment_ids, - model_mode=model_mode, - mesh=self.mesh, - quant=self.quant, - cfg=cfg, - policy=policy, - ) + if self.is_deepseek: + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.scan_layers: + if cfg.engram_layers: + y, self.dense_layers = self._apply_interleaved_scanned_layers( + self.dense_layers, + y, + layer_args, + layer_kwargs, + start_idx=0, + end_idx=cfg.first_num_dense_layers, + engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + if hasattr(self, "moe_layers_outside_pipeline"): + num_moe_outside = (cfg.num_decoder_layers - cfg.first_num_dense_layers) - cfg.pipeline_parallel_layers + y, self.moe_layers_outside_pipeline = self._apply_interleaved_scanned_layers( + self.moe_layers_outside_pipeline, + y, + layer_args, + layer_kwargs, + start_idx=cfg.first_num_dense_layers, + end_idx=cfg.first_num_dense_layers + num_moe_outside, + engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + else: + y, new_state = self._apply_layers_sequentially( + self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs + ) + self._trace_safe_update(self.dense_layers, new_state) + if hasattr(self, "moe_layers_outside_pipeline"): + num_moe_outside 
= (cfg.num_decoder_layers - cfg.first_num_dense_layers) - cfg.pipeline_parallel_layers + y, new_state = self._apply_layers_sequentially( + self.moe_layers_outside_pipeline, y, *layer_args, length=num_moe_outside, **layer_kwargs + ) + self._trace_safe_update(self.moe_layers_outside_pipeline, new_state) else: - y, self.moe_layer = self._apply_layers_sequentially( - self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs + y = self._run_unscanned_layers_loop( + base_name="dense_layers", + num_layers=self.num_dense_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=0, + decoder_input_tokens=decoder_input_tokens, ) - elif self.is_gemma3: - y = self._apply_gemma3_scanned_blocks( + if hasattr(self, "num_moe_outside_pipeline") and self.num_moe_outside_pipeline > 0: + y = self._run_unscanned_layers_loop( + base_name="moe_layers_outside_pipeline", + num_layers=self.num_moe_outside_pipeline, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=cfg.first_num_dense_layers, + decoder_input_tokens=decoder_input_tokens, + ) + + y = self.pipeline_module( y, decoder_segment_ids, decoder_positions, deterministic, model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, + logical_partition_spec=logical_partition_spec, ) - else: - scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) - else: - prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) - - # Hoisted function to preserve XLA cache ID - def pure_layer_fn(graphdef, state_in, y_in, kv_in): - if cfg.parameter_memory_host_offload: - state_in = jax.tree.map(lambda x: 
jax.device_put(x, max_utils.device_space()), state_in) - - merged_layer = nnx.merge(graphdef, state_in) - out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) - return out_y, out_kv, nnx.state(merged_layer) - - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - - for lyr, layer in enumerate(self.layers): - graphdef, state = nnx.split(layer) - kv_cache = kv_caches[lyr] if kv_caches is not None else None - - input_tokens = decoder_input_tokens if cfg.engram_layers else None - if input_tokens is not None: - layer_kwargs["decoder_input_tokens"] = input_tokens - - y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache) - nnx.update(layer, new_state) + else: + # Standard Pipeline Run + y = self.pipeline_module( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + logical_partition_spec=logical_partition_spec, + ) - if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache + # Remaining standard layers + if hasattr(self, "num_layers_outside_pipeline") or hasattr(self, "layers_outside_pipeline"): + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.scan_layers: + y, new_state = self._apply_layers_sequentially( + self.layers_outside_pipeline, + y, + *layer_args, + length=self.num_layers_outside_pipeline, + **layer_kwargs, + ) + self._trace_safe_update(self.layers_outside_pipeline, new_state) + else: + y = self._run_unscanned_layers_loop( + base_name="layers_outside_pipeline", + num_layers=self.num_layers_outside_pipeline, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=cfg.pipeline_parallel_layers, + decoder_input_tokens=decoder_input_tokens, + ) - if deepstack_visual_embeds 
is not None and lyr < len(deepstack_visual_embeds): - visual_embeds = deepstack_visual_embeds[lyr] - if bidirectional_mask is not None and visual_embeds is not None: - y = deepstack_process(y, bidirectional_mask, visual_embeds) + else: + # Non-Pipeline Run + if cfg.scan_layers: + if self.is_deepseek: + if cfg.engram_layers: + y, self.dense_layers = self._apply_interleaved_scanned_layers( + self.dense_layers, + y, + layer_args, + layer_kwargs, + start_idx=0, + end_idx=cfg.first_num_dense_layers, + engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + y, self.moe_layers = self._apply_interleaved_scanned_layers( + self.moe_layers, + y, + layer_args, + layer_kwargs, + start_idx=cfg.first_num_dense_layers, + end_idx=cfg.num_decoder_layers, + engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + else: + y, new_state = self._apply_layers_sequentially( + self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs + ) + self._trace_safe_update(self.dense_layers, new_state) + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + + # Use raw deepseek_batchsplit logic for MoE scanned layers to minimize VRAM overhead + layer_is_initializing = self.quant is not None and len(nnx.state(self.moe_layers, "aqt")) == 0 + if cfg.use_batch_split_schedule and not layer_is_initializing: + raw_weights = nnx.to_pure_dict(nnx.state(self.moe_layers, nnx.Param)) + y = deepseek_batchsplit.scan_batch_split_layers( + y, + raw_weights, + decoder_positions, + decoder_segment_ids, + model_mode=model_mode, + mesh=self.mesh, + quant=self.quant, + cfg=cfg, + policy=self.get_remat_policy(), + ) + else: + y, new_state = self._apply_layers_sequentially( + self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs + ) + self._trace_safe_update(self.moe_layers, new_state) + + elif self.is_gemma3: + y = self._apply_gemma3_scanned_blocks( + y, + 
decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ) + else: + y, new_state = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=cfg.num_decoder_layers, **layer_kwargs + ) + self._trace_safe_update(self.layers, new_state) + else: + if self.is_deepseek: + y = self._run_unscanned_layers_loop( + base_name="dense_layers", + num_layers=self.num_dense_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=0, + decoder_input_tokens=decoder_input_tokens, + ) + y = self._run_unscanned_layers_loop( + base_name="moe_layers", + num_layers=self.num_moe_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=cfg.first_num_dense_layers, + decoder_input_tokens=decoder_input_tokens, + ) + else: + y = self._run_unscanned_layers_loop( + base_name="layers", + num_layers=self.num_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=0, + decoder_input_tokens=decoder_input_tokens, + ) assert isinstance(y, jax.Array) - # After the final transformer layer, `y` holds the raw, un-normalized hidden state. if cfg.mhc_expansion_rate > 1: # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) + hidden_state = mhc_reduce(y) else: hidden_state = y # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage. 
if cfg.attention == "vllm_rpa": + if not cfg.logits_via_embedding and hasattr(self, "logits_dense"): + if self.quant is not None and len(nnx.state(self.logits_dense, "aqt")) == 0: + _ = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) logits = None - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) - else: logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) return logits, hidden_state, kv_caches + def _trace_safe_update(self, layer, new_state): + """Updates the layer state only if not currently in a tracing context where mutation is forbidden.""" + # Check if we are in an abstract tracing context (like jax.eval_shape) + # where mutation of variables from outer scopes is disallowed. 
+ is_tracing = any(isinstance(x, jax.core.Tracer) for x in jax.tree_util.tree_leaves(new_state)) + if not is_tracing: + nnx.update(layer, new_state) + def _apply_gemma3_scanned_blocks( self, y, @@ -1093,7 +1440,8 @@ def _apply_gemma3_scanned_blocks( # Apply the main scan over the full blocks if scan_length > 0: - y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + y, new_state = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + self._trace_safe_update(self.layers, new_state) # Apply any remaining layers that did not fit into a full scanned block num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length @@ -1109,10 +1457,9 @@ def pure_gemma_fn(graphdef, state_in, y_in): return out_y, nnx.state(merged_layer) checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) - graphdef, state = nnx.split(self.layers_remainder) y, new_state = checkpointed_gemma_fn(graphdef, state, y) - nnx.update(self.layers_remainder, new_state) + self.layers_remainder = nnx.merge(graphdef, new_state) return y diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 6a1aba8470..aa01748fd5 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -26,6 +26,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph +from flax.nnx import tracers as nnx_tracers from flax.nnx import variablelib from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module @@ -170,6 +171,23 @@ def current_linen_module() -> linen.Module | None: return None +def _refresh_variable_trace_state(module: Module) -> None: + """Refresh _trace_state for Variables that have stale trace state. + + When nnx.update() is called with tracer values from a JAX transformation + (e.g. 
jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True which + updates the raw value but not _trace_state. This leaves Variables with a + stale _trace_state from the outer (Python) context, causing nnx.split() to + fail with "Cannot extract graph node from different trace level" errors. + + This function resets _trace_state on any Variables whose _can_update is False + so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed. + """ + for _, v in nnx.graph.iter_graph(module): + if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access + object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) + + class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. @@ -467,6 +485,7 @@ def maybe_unbox(x): warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) + _refresh_variable_trace_state(module) _fix_for_qwix_quantization(module) method_fn = _get_module_method(module, nnx_method) diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index 3bce30d44e..a3d23bb1d8 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -83,7 +83,7 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> scale = self.scale.value # Move scale to device if parameter offloading is enabled - if self.parameter_memory_host_offload: + if scale is not None and self.parameter_memory_host_offload: max_logging.log("normalizations.py: Moving scale parameter to device") scale = jax.device_put(scale, max_utils.device_space()) @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float = 1e-6, + dtype: DType = None, + 
weight_dtype: DType = None, + shard_mode=None, + kernel_axes=None, + parameter_memory_host_offload=None, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, scale_init=linen_initializers.zeros, diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 1b130f1888..b8417d9811 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -26,7 +26,7 @@ from flax.core import meta from flax import linen as nn -from flax.linen.spmd import LogicallyPartitioned +from flax import nnx from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode from maxtext.utils.sharding import ( @@ -37,25 +37,22 @@ logical_to_mesh, ) from maxtext.utils import pipeline_utils +from maxtext.utils import max_logging -class PipelineBase(nn.Module): - """Base module that implements shared pipelining logic across stages.""" +class PipelineSharedMixin: + """Contains pure JAX and mathematical utilities shared identically by Linen and NNX.""" - config: Config - layers: nn.Module - mesh: Mesh - remat_policy: Any = None - - def setup(self): + def _setup_pipeline_attributes(self): """Initializes the configuration, calculating num_stages, delay, axes, and partition specs.""" self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches - microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages - self.microbatches_per_stage = microbatches_per_stage + self.microbatches_per_stage = 
self.config.num_pipeline_microbatches // self.num_stages self.use_circ_storage = self.need_circ_storage() + self.spmd_axis_name = "stage" if self.config.shard_mode == ShardMode.AUTO else None + if self.config.expert_shard_attention_option == EP_AS_CONTEXT: self.batch_axis_name = "activation_batch_no_exp" self.seq_len_axis_name = "activation_length" @@ -63,8 +60,6 @@ def setup(self): self.batch_axis_name = "activation_batch" self.seq_len_axis_name = "activation_length_no_exp" - self.spmd_axis_name = "stage" if self.config.shard_mode == ShardMode.AUTO else None - self.stages_in_logical = ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed") self.stages_in_spec = logical_to_mesh_axes(self.stages_in_logical, self.mesh, rules=self.config.logical_axis_rules) self.stages_in_sharding = ( @@ -176,8 +171,7 @@ def select_state_or_input(first_stage_in, shift): # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) stages_in = select_state_or_input(first_stage_in, shift) - stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical) - return stages_in + return self._maybe_shard_with_logical(stages_in, self.stages_in_logical) def get_microbatch_and_repeat_ids(self, loop_iteration): """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. 
Works for both circular and @@ -193,140 +187,10 @@ def get_pipeline_remat_policy(self): """Returns the pipeline remat policy for this pipeline.""" if self.config.remat_policy == "custom": return self.remat_policy - save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") if self.remat_policy is not None: - remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) - else: - remat_policy = save_input_policy - return remat_policy - - def get_weight_sharding(self, *init_args): - """get weight sharding function for this pipeline.""" - key = jax.random.PRNGKey(0) - keys = {"params": key, "dropout": key, "aqt": key} - weights = self.init(keys, *init_args) - - def get_partition_spec(pytree): - def _is_leaf(x): - return isinstance(x, nn.spmd.LogicallyPartitioned) - - def get_partition_spec_leaf(leaf): - return leaf.get_partition_spec() - - return jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf) - - partition_spec_with_extra_layer = get_partition_spec(weights) - logical_partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]} - return logical_partition_spec - - def get_vmap_func_for_init(self): - """This vmap func is used to initialize the weights only on init.""" - - def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): - return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) - - vmap_func = nn.vmap( - func_to_vmap, - in_axes=(0, 0, 0, None, None), - spmd_axis_name=self.spmd_axis_name, - variable_axes={"params": 0, "_overwrite_with_gradient": 0}, - split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) - return vmap_func - - def 
get_main_vmap_func_for_iterations(self): - """ - Returns main stage function vmapped by number of stages. - This becomes a vmap over a single layer instance if body_instance is a single layer, - else a set of layers if body_instance is a set of layers. - """ - - def func_to_vmap( - body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode - ): - weights = meta.remove_axis( - weights, - 0, - { - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) - return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) - - vmap_func = nn.vmap( - func_to_vmap, - in_axes=(0, 0, 0, 0, None, None), - spmd_axis_name=self.spmd_axis_name, - variable_axes={"params": 0}, - split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) - return vmap_func - - def _run_weight_initialization( - self, example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode - ): - """Runs the initialization sequence mapping layers appropriately based on pipeline settings.""" - vmap_func = self.get_vmap_func_for_init() - - if self.config.num_pipeline_repeats > 1: - vmap_func = nn.vmap( - vmap_func, - in_axes=(0, segment_idx, position_idx, None, None), - variable_axes={"params": 0, "_overwrite_with_gradient": 0, "non_trainable": 0, "hyper_params": 0}, - split_rngs={"params": True, "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": True, - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - }, - ) - example_inputs = 
jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) - example_segmentation = ( - jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) - if example_segmentation is not None - else None - ) - example_position = ( - jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) - if example_position is not None - else None - ) - - example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) - stage_outputs = vmap_func( - self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode - ) - if self.config.scan_layers: - stage_outputs = stage_outputs[0] - if self.config.num_pipeline_repeats > 1: - stage_outputs = stage_outputs[0] - broadcasted_stage_outpus = jax.lax.broadcast( - stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] - ) - - return jnp.reshape( - broadcasted_stage_outpus, - [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], - out_sharding=self.output_sharding, - ) + return jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) + return save_input_policy @staticmethod def _remove_fsdp_from_physical_partition_spec(pps): @@ -353,10 +217,6 @@ def _remove_fsdp_from_physical_partition_spec(pps): return P(*new_spec) return pps - -class Pipeline(PipelineBase): - """Original Pipeline implementation.""" - def init_states(self, inputs): """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] @@ -389,6 +249,7 @@ def init_states(self, inputs): state_io = jnp.reshape( inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:], out_sharding=self.state_io_sharding ) + # We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. 
state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical) @@ -410,7 +271,7 @@ def init_states(self, inputs): circ_storage = None circ_storage_mover = None - init_loop_state = { + return { "state_io": state_io, "shift": shift, "circ_storage": circ_storage, @@ -418,12 +279,14 @@ def init_states(self, inputs): "loop_iteration": 0, "prev_outputs": prev_outputs, } - return init_loop_state def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False): """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at the specified dimension.""" placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED + if x.ndim == 0 or dim >= x.ndim: + # Scalar or out-of-bounds dim (e.g. repeat_ids inside vmap over stage axis). No-op. + return x if physical_partition_spec is None: dims_mapping = [placeholder] * x.ndim else: @@ -472,10 +335,9 @@ def _gather_one(x, repeat_id): stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( weights, repeat_ids ) - stage_weights = self.shard_dim_by_stages( + return self.shard_dim_by_stages( stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True ) - return stage_weights def vmap_gather(self, xs, ids, ids_dim): """Use vmap to implement a stage-wise sharded gather. @@ -492,9 +354,11 @@ def vmap_gather(self, xs, ids, ids_dim): The per-stage gathered values. The shape is xs.shape but with ids_dim size replaced with [num_stages]. 
""" + xs = jnp.asarray(xs) + ndim = xs.ndim def _gather_one(x, i): - idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) + idx = tuple(i if d == ids_dim else slice(None) for d in range(ndim)) replicated_sharding = NamedSharding(self.mesh, P()) return x.at[idx].get(out_sharding=replicated_sharding) @@ -502,17 +366,11 @@ def _gather_one(x, i): outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None) - def get_new_loop_state(self, output, loop_state): - """ - Update the various buffers given the output of the most recent iteration - * state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output) - * Pushing inputs up from top of state_io into first stage of shift - * Pulling outputs up from last stage of shift into bottom of state_io - * shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to - right/down - * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage - * circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration - * prev_outputs: is set to the current output + def advance_circular_buffers(self, output, loop_state): + """Rotates pipeline activations to the next physical device stage. + + Uses `jax.lax.ppermute` to perform cross-device ring communication, shifting + the forward activations (`state_io` and `shift`) from stage $i$ to stage $i+1$. 
""" old_state_io = loop_state["state_io"] old_circ_storage = loop_state["circ_storage"] @@ -522,11 +380,9 @@ def get_new_loop_state(self, output, loop_state): @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _rotate_right(arr): - # we use +1 for right shifting stage_size = jax.lax.axis_size("stage") perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return arr + return jax.lax.ppermute(arr, axis_name="stage", perm=perm) @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _shift_right(arr): @@ -558,8 +414,7 @@ def _update_shift(output_in): # circ_storage_mover still points to the output of PREVIOUS iteration, which should aid in allowing overlapped # compute/async transfers def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): - rotated = _rotate_right(circ_storage_mover_in) - rotated = jnp.expand_dims(rotated, 1) + rotated = jnp.expand_dims(_rotate_right(circ_storage_mover_in), 1) # We rotate the pushing index into circ storage, and ensure that microbatch 0 lands in index 0 offset = ( loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1 @@ -602,7 +457,7 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx): new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) - new_loop_state = { + return { "state_io": new_state, "shift": new_shift, "circ_storage": new_circ_storage, @@ -610,7 +465,6 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx): "loop_iteration": loop_iteration + 1, "prev_outputs": new_prev_outputs, } - return new_loop_state def permute_output_micro_per_stage_dim(self, output): """ @@ -626,8 +480,18 @@ def permute_output_micro_per_stage_dim(self, output): # state_io - it will land on a different index of state_io depending on the number of iterations. 
microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage permutation = (np.arange(self.microbatches_per_stage) + microbatch_0_idx) % self.microbatches_per_stage - output = output[:, permutation] - return output + return output[:, permutation] + + def realign_output_microbatches(self, output): + """Reorders the output tensor to reverse the circular shifts applied during execution. + + Because the pipeline operates circularly, the output microbatches are shifted + out of order by the time the final stage is completed. This rolls them back + into their original sequential layout. + """ + microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage + output = jnp.roll(output, shift=-microbatch_0_idx, axis=1) + return self._maybe_shard_with_logical(output, self.state_io_logical) def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None): """ @@ -640,102 +504,279 @@ def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_p return self.get_current_repeat_from_stages( pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec ) - else: - return pipeline_weights + return pipeline_weights - def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): - """Fetches the weights for the current repeat from the stages.""" - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, - # only one circular entry per stage. 
- weights = meta.remove_axis(weights, 0, circular_metadata_params) - weights = self._remove_logically_partition(weights) + def all_gather_over_fsdp(self, variables, logical_partition_spec): + """ + all-gathers the variables over fsdp if fsdp is in the logical partition spec. + """ + if logical_partition_spec is None: + return variables + + def _gather_leaf(var, spec): + if spec is None: + return var + physical = logical_to_mesh_axes(spec, self.mesh, rules=self.config.logical_axis_rules) + no_fsdp = self._remove_fsdp_from_physical_partition_spec(physical) + sharding = NamedSharding(self.mesh, no_fsdp) + if isinstance(var, nnx.Variable): + var.value = self._maybe_shard_with_name(var.value, sharding) + return var + return self._maybe_shard_with_name(var, sharding) + + # nnx.Variable and PartitionSpec are JAX pytree nodes — treat them as leaves + # so the two trees align at the dict level. None must also be a leaf to avoid + # being treated as an empty container (0 children) vs the Variable's 1 child. 
+ def is_leaf(x): + return isinstance(x, (nnx.Variable, P)) or x is None + + return jax.tree.map(_gather_leaf, variables, logical_partition_spec, is_leaf=is_leaf) + + def get_logical_spec_repeats_removed(self, full_logical): + """Returns a new logical spec with 'circular_repeats' removed.""" + if full_logical is None or self.config.num_pipeline_repeats == 1: + return full_logical - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec - ) + def _remove_from_spec(spec): + if not isinstance(spec, P): + return spec + if spec and (spec[0] == "circular_repeats" or spec[0] is None): + return jax.sharding.PartitionSpec(*spec[1:]) + return jax.sharding.PartitionSpec(*[dim for dim in spec if dim != "circular_repeats"]) - if physical_partition_spec is None: - weights = jax.tree.map(gather_weights_for_stages_in, weights) - else: - weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) - return weights + return jax.tree.map(_remove_from_spec, full_logical, is_leaf=lambda x: isinstance(x, P)) - def run_one_iteration( - self, - loop_state, - pipeline_weights, - positions, - segment_ids, - deterministic, - model_mode, - decoder_layer_instance, - logical_partition_spec=None, - ): - """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, - and update the loop state. - Args: - loop_state: Dictionary containing the current state of the pipeline (state_io, shift, etc.) - positions: Positional encodings. - segment_ids: Segment IDs for packed sequences. - deterministic: Boolean indicating if execution should be deterministic (e.g. for dropout). - model_mode: Current model mode (train/predict). - logical_partition_spec: Logical partition specification for weights. 
- """ - state_io = loop_state["state_io"] - shift = loop_state["shift"] - circ_storage = loop_state["circ_storage"] - loop_iteration = loop_state["loop_iteration"] +class PipelineBaseLinen(nn.Module, PipelineSharedMixin): + """Base module that implements shared pipelining logic across stages for Linen.""" - microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) - physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + config: Config + layers: nn.Module + mesh: Mesh + remat_policy: Any = None - stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) - stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") - stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None - stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + def setup(self): + self._setup_pipeline_attributes() - vmap_func = self.get_main_vmap_func_for_iterations() + def get_weight_sharding(self, *init_args): + """get weight sharding function for this pipeline.""" + key = jax.random.PRNGKey(0) + keys = {"params": key, "dropout": key, "aqt": key} + weights = self.init(keys, *init_args) - if self.config.num_pipeline_repeats > 1: - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + def get_partition_spec(pytree): + def _is_leaf(x): + return isinstance(x, nn.spmd.LogicallyPartitioned) - def prepare_vars_for_main_vmap(weights, physical_partition_spec=None): - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - weights = meta.remove_axis(weights, 0, circular_metadata_params) - weights = self._remove_logically_partition(weights) + def get_partition_spec_leaf(leaf): + 
return leaf.get_partition_spec() - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec - ) + return jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf) - if physical_partition_spec is None: - weights = jax.tree.map(gather_weights_for_stages_in, weights) - else: - weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) - return weights + partition_spec_with_extra_layer = get_partition_spec(weights) + return {"params": partition_spec_with_extra_layer["params"]["layers"]} - prepare_vars_for_main_vmap_partial = functools.partial( - prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec - ) - vmap_func = nn.map_variables( - vmap_func, + def get_vmap_func_for_init(self): + """This vmap func is used to initialize the weights only on init.""" + + def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + return nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, None, None), + spmd_axis_name=self.spmd_axis_name, + variable_axes={"params": 0, "_overwrite_with_gradient": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + + def get_main_vmap_func_for_iterations(self): + """ + Returns main stage function vmapped by number of stages. + This becomes a vmap over a single layer instance if body_instance is a single layer, + else a set of layers if body_instance is a set of layers. 
+ """ + + def func_to_vmap( + body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode + ): + weights = meta.remove_axis( + weights, + 0, + { + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + return nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, 0, None, None), + spmd_axis_name=self.spmd_axis_name, + variable_axes={"params": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + + def _run_weight_initialization( + self, example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode + ): + """Runs the initialization sequence mapping layers appropriately based on pipeline settings.""" + vmap_func = self.get_vmap_func_for_init() + + if self.config.num_pipeline_repeats > 1: + vmap_func = nn.vmap( + vmap_func, + in_axes=(0, segment_idx, position_idx, None, None), + variable_axes={"params": 0, "_overwrite_with_gradient": 0, "non_trainable": 0, "hyper_params": 0}, + split_rngs={"params": True, "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": True, + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + }, + ) + example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) + example_segmentation = ( + jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) + if example_segmentation is not None + else None + ) + example_position = ( + 
jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) + if example_position is not None + else None + ) + + example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) + stage_outputs = vmap_func( + self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode + ) + if self.config.scan_layers: + stage_outputs = stage_outputs[0] + if self.config.num_pipeline_repeats > 1: + stage_outputs = stage_outputs[0] + broadcasted_stage_outpus = jax.lax.broadcast( + stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] + ) + + return jnp.reshape( + broadcasted_stage_outpus, + [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], + out_sharding=self.output_sharding, + ) + + +class Pipeline(PipelineBaseLinen): + """Original Pipeline implementation.""" + + def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): + """Fetches the weights for the current repeat from the stages.""" + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + + # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, + # only one circular entry per stage. 
+ weights = meta.remove_axis(weights, 0, circular_metadata_params) + weights = pipeline_utils.remove_logically_partition(weights) + + def gather_weights_for_stages_in(w, spec=None): + return self.vmap_parallel_gather( + w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + ) + + if physical_partition_spec is None: + return jax.tree.map(gather_weights_for_stages_in, weights) + return jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) + + def run_one_iteration( + self, + loop_state, + pipeline_weights, + positions, + segment_ids, + deterministic, + model_mode, + decoder_layer_instance, + logical_partition_spec=None, + ): + """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, + and update the loop state. + + Args: + loop_state: Dictionary containing the current state of the pipeline (state_io, shift, etc.) + positions: Positional encodings. + segment_ids: Segment IDs for packed sequences. + deterministic: Boolean indicating if execution should be deterministic (e.g. for dropout). + model_mode: Current model mode (train/predict). + logical_partition_spec: Logical partition specification for weights. 
+ """ + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + + vmap_func = self.get_main_vmap_func_for_iterations() + + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def prepare_vars_for_main_vmap(weights, physical_partition_spec=None): + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + weights = meta.remove_axis(weights, 0, circular_metadata_params) + weights = pipeline_utils.remove_logically_partition(weights) + + def gather_weights_for_stages_in(w, spec=None): + return self.vmap_parallel_gather( + w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + ) + + if physical_partition_spec is None: + return jax.tree.map(gather_weights_for_stages_in, weights) + return jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) + + prepare_vars_for_main_vmap_partial = functools.partial( + prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec + ) + vmap_func = nn.map_variables( + vmap_func, mapped_collections=["params", "_overwrite_with_gradient", 
"non_trainable", "summaries", "intermediates"], mutable=True, trans_in_fn=prepare_vars_for_main_vmap_partial, @@ -756,42 +797,7 @@ def gather_weights_for_stages_in(w, spec=None): if self.config.scan_layers: stages_output = stages_output[0] - new_state = self.get_new_loop_state(stages_output, loop_state) - return new_state - - @staticmethod - def get_logical_spec_repeats_removed(full_logical): - """Returns a new logical spec with 'circular_repeats' removed.""" - if full_logical is None: - return None - - def _remove_from_spec(spec): - return jax.sharding.PartitionSpec(*[dim for dim in spec if dim != "circular_repeats"]) - - return jax.tree.map(_remove_from_spec, full_logical) - - @staticmethod - def _remove_logically_partition(weights): - """Removes LogicallyPartitioned wrappers from the variables.""" - - def _remove_logically_partition_leaf(v): - return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v - - return jax.tree.map(_remove_logically_partition_leaf, weights, is_leaf=lambda v: isinstance(v, LogicallyPartitioned)) - - def all_gather_over_fsdp(self, variables, logical_partition_spec): - """Gathers FSDP partitioned variables to reconstruct them fully.""" - physical_partition_spec = logical_to_mesh( - logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules - ) - physical_partition_spec_no_fsdp = jax.tree.map( - self._remove_fsdp_from_physical_partition_spec, physical_partition_spec - ) - return jax.tree.map( - lambda w, p: self._maybe_shard_with_name(w, NamedSharding(self.mesh, p)), - variables, - physical_partition_spec_no_fsdp, - ) + return self.advance_circular_buffers(stages_output, loop_state) @nn.compact def __call__( @@ -816,12 +822,12 @@ def __call__( ), out_sharding=self.input_sharding, ) + example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None)) if positions is not None: - positions = 
self._maybe_shard_with_name(positions, ag_sharding) - positions = positions.reshape( + positions = self._maybe_shard_with_name(positions, ag_sharding).reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) @@ -831,8 +837,7 @@ def __call__( position_idx = None if segment_ids is not None: - segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding) - segment_ids = segment_ids.reshape( + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) @@ -862,7 +867,7 @@ def __call__( ) if self.config.pipeline_fsdp_ag_once: - variables = self._remove_logically_partition(self.layers.variables) + variables = pipeline_utils.remove_logically_partition(self.layers.variables) all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec) else: all_pipeline_weights = self.layers.variables @@ -903,6 +908,7 @@ def run_iteration_scannable(model, loop_state, xs): variable_carry.append("non_trainable") else: variable_broadcast.append("non_trainable") + run_all_iterations_scanned = nn.scan( run_iteration_scannable, variable_axes={"summaries": 0, "aux_loss": 0, "intermediates": 0, "hyper_params": 0}, @@ -921,15 +927,14 @@ def run_iteration_scannable(model, loop_state, xs): # the input final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] - final_output = jnp.reshape( + return jnp.reshape( final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), out_sharding=self.output_sharding, ) - return final_output -class CircularPipeline(PipelineBase): +class 
CircularPipeline(PipelineBaseLinen): """Implements an circular pipeline schedule with asynchronous weight prefetching. Circular pipelining reduces the pipeline "bubble" by interleaving multiple pipeline @@ -946,26 +951,7 @@ def init_states(self, inputs): (`state_io` and `shift`) and allocates the empty Buffer Sliding Window (BSW) that will hold the gathered FSDP weights. """ - shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) - - if self.config.pipeline_delay_activation_forwarding: - prev_outputs = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - prev_outputs = self._maybe_shard_with_logical(prev_outputs, self.stages_in_logical) - else: - prev_outputs = None - - state_io = jnp.reshape( - inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:], out_sharding=self.state_io_sharding - ) - state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical) - - if self.use_circ_storage: - circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype, out_sharding=self.state_io_sharding) - circ_storage_mover = shift - else: - circ_storage = None - circ_storage_mover = None + init_loop_state = super().init_states(inputs) def _init_empty_bsw_buffers(variables): # BSW requires two buffers (current and next) for the sliding window @@ -980,30 +966,11 @@ def _init_empty_bsw_buffers(variables): variables = pipeline_utils.remove_logically_partition(self.layers.variables) bsw = _init_empty_bsw_buffers(variables) - init_loop_state = { - "state_io": state_io, - "shift": shift, - "circ_storage": circ_storage, - "circ_storage_mover": circ_storage_mover, - "loop_iteration": 0, - "prev_outputs": prev_outputs, - } return init_loop_state, bsw - def gather_weights_across_stages_vmap(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): - """Uses jax.vmap to dynamically slice and gather weights for specific 
pipeline repeats.""" - - def _gather_single_repeat(x, repeat_id): - return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) - - gathered_weights_stage_dim = 0 - stage_weights = jax.vmap( - _gather_single_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim - )(weights, repeat_ids) - return stage_weights - def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim): """Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch.""" + xs = jnp.asarray(xs) # Safe casting for non-JAX arrays def _gather_one(x, i): idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) @@ -1016,111 +983,13 @@ def _gather_one(x, i): return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) - def advance_circular_buffers(self, output, loop_state): - """Rotates pipeline activations to the next physical device stage. - - Uses `jax.lax.ppermute` to perform cross-device ring communication, shifting - the forward activations (`state_io` and `shift`) from stage $i$ to stage $i+1$. 
- """ - old_state_io = loop_state["state_io"] - old_circ_storage = loop_state["circ_storage"] - old_circ_storage_mover = loop_state["circ_storage_mover"] - loop_iteration = loop_state["loop_iteration"] - - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) - def _rotate_right(arr): - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - return jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) - def _shift_right(arr): - stage_idx = jax.lax.axis_index("stage") - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) - - def _update_shift(output_in): - if self.config.num_pipeline_repeats == 1 or self.use_circ_storage: - return _shift_right(output_in) - else: - return _rotate_right(output_in) - - new_shift = _update_shift(output) - new_prev_outputs = None - - if self.use_circ_storage: - - def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): - rotated = _rotate_right(circ_storage_mover_in) - rotated = jnp.expand_dims(rotated, 1) - offset = ( - loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1 - ) % self.config.num_pipeline_microbatches - return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) - - new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage) - new_circ_storage_mover = output - else: - new_circ_storage = None - new_circ_storage_mover = None - - stream_buf_idx = loop_iteration % self.microbatches_per_stage - stream_slice = old_state_io[:, stream_buf_idx] - - def _rotate_left(arr, stage_size): - perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] - return 
jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - def _shift_left(arr, stage_size, output): - stage_idx = jax.lax.axis_index("stage") - arr = _rotate_left(arr, stage_size) - return jnp.where(stage_idx == stage_size - 1, output, arr) - - @jax.shard_map( - mesh=self.mesh, - in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), - out_specs=self.state_io_spec, - check_vma=True, - ) - def _update_state_io(state_in, stream_slice, output, stream_buf_idx): - stage_size = jax.lax.axis_size("stage") - stream_slice = _shift_left(stream_slice, stage_size, output) - stream_slice = jnp.expand_dims(stream_slice, 1) - return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) - - new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) - new_loop_state = { - "state_io": new_state, - "shift": new_shift, - "circ_storage": new_circ_storage, - "circ_storage_mover": new_circ_storage_mover, - "loop_iteration": loop_iteration + 1, - "prev_outputs": new_prev_outputs, - } - return new_loop_state - - def realign_output_microbatches(self, output): - """Reorders the output tensor to reverse the circular shifts applied during execution. - - Because the pipeline operates circularly, the output microbatches are shifted - out of order by the time the final stage is completed. This rolls them back - into their original sequential layout. - """ - microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage - output = jnp.roll(output, shift=-microbatch_0_idx, axis=1) - output = self._maybe_shard_with_logical(output, self.state_io_logical) - return output - def fetch_active_stage_weights(self, bsw, loop_iteration, physical_partition_spec=None, is_initializing=None): """The module fetches the actively prefetched weights from the Buffer Sliding Window to avoid mid-iteration FSDP all-gathers. 
""" - pipeline_weights = self.get_current_weights_from_bsw( + return self.get_current_weights_from_bsw( bsw, loop_iteration, physical_partition_spec=physical_partition_spec, is_initializing=is_initializing ) - return pipeline_weights def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None): """Pulls the fully gathered parameters for the current repeat from the BSW dual-buffer.""" @@ -1144,8 +1013,19 @@ def select_weights_from_bsw(bsw, repeat_id): "x_times": self.config.num_pipeline_repeats, "optimizer_dims_mapping": None, } - weights = meta.remove_axis(weights, 0, circular_metadata_params) - return weights + return meta.remove_axis(weights, 0, circular_metadata_params) + + def gather_weights_across_stages_vmap(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): + """Uses jax.vmap to dynamically slice and gather weights for specific pipeline repeats.""" + + def _gather_single_repeat(x, repeat_id): + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) + + gathered_weights_stage_dim = 0 + stage_weights = jax.vmap( + _gather_single_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim + )(weights, repeat_ids) + return stage_weights def from_all_variables_to_repeat_weights(self, weights, loop_iteration): """Gathers weights corresponding to the repeat IDs for current iteration.""" @@ -1166,8 +1046,7 @@ def gather_weights_for_stages_in(w): "x_times": self.config.num_pipeline_repeats, "optimizer_dims_mapping": None, } - repeat_weights = meta.remove_axis(weights, 0, circular_metadata_params) - return repeat_weights + return meta.remove_axis(weights, 0, circular_metadata_params) def from_repeat_weights_to_bsw( self, @@ -1181,18 +1060,13 @@ def from_repeat_weights_to_bsw( axes_to_remove = ["fsdp", "fsdp_transpose", "context"] bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, 
axes_to_remove) - def _from_repeat_weights_to_bsw_shardmap( - repeat_weights, - physical_partition_spec, - axes_to_gather, - ): + def _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather): repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec) # Dynamically gather the index pytrees for all specified axes axis_indices_dict = { axis: pipeline_utils.get_mesh_axis_dim_indices(physical_partition_spec, axis) for axis in axes_to_gather } - axis_names = list(axis_indices_dict.keys()) axis_pytrees = list(axis_indices_dict.values()) @@ -1211,7 +1085,6 @@ def should_skip_gather(axis_name, path_keys): check_vma=False, ) def _shard_map_gather_weights(sharded_weights, indices_pytrees_list): - # Renamed to clarify we are gathering a single tensor iteratively along requested axes def _gather_tensor_along_axes(path, x, *indices): path_keys = [getattr(p, "key", str(p)) for p in path] @@ -1226,9 +1099,7 @@ def _gather_tensor_along_axes(path, x, *indices): return _shard_map_gather_weights(repeat_weights, axis_pytrees) - def _from_repeat_weights_to_bsw_hint( - repeat_weights, - ): + def _from_repeat_weights_to_bsw_hint(repeat_weights): def _apply_sharding_hint(weight, pspec): sharding_name = NamedSharding(self.mesh, pspec) return maybe_shard_with_name( @@ -1290,8 +1161,7 @@ def run_one_iteration(self, loop_state, bsw, positions, segment_ids, determinist if self.config.scan_layers: stages_output = stages_output[0] - new_state = self.advance_circular_buffers(stages_output, loop_state) - return new_state + return self.advance_circular_buffers(stages_output, loop_state) @nn.compact def __call__( @@ -1317,8 +1187,7 @@ def __call__( ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None)) if positions is not None: - positions = self._maybe_shard_with_name(positions, ag_sharding) - positions = positions.reshape( + positions = self._maybe_shard_with_name(positions, ag_sharding).reshape( 
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) @@ -1328,8 +1197,7 @@ def __call__( position_idx = None if segment_ids is not None: - segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding) - segment_ids = segment_ids.reshape( + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) @@ -1352,19 +1220,18 @@ def __call__( logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec) - def run_iteration_scannable(model, loop_state, bsw): - return ( - model.run_one_iteration( - loop_state, - bsw, - positions, - segment_ids, - deterministic, - model_mode, - logical_partition_spec=logical_partition_spec, - ), - None, + def run_iteration_scannable(model, loop_state, bsw, weights): + new_loop_state = model.run_one_iteration( + loop_state, + bsw, + weights, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec=logical_partition_spec, ) + return (new_loop_state, bsw, weights), None if self.config.set_remat_policy_on_pipeline_iterations: run_iteration_scannable = nn.remat( @@ -1372,7 +1239,6 @@ def run_iteration_scannable(model, loop_state, bsw): prevent_cse=not self.config.scan_pipeline_iterations, policy=self.get_pipeline_remat_policy(), ) - # base scannable function used twice for real and bubble runs base_scannable = functools.partial( pipeline_utils.create_pipeline_stage, @@ -1405,18 +1271,736 @@ def run_iteration_scannable(model, loop_state, bsw): (loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles) final_output = self.realign_output_microbatches(loop_state["state_io"]) - final_output = jnp.reshape( + return jnp.reshape( final_output, 
(self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), out_sharding=self.output_sharding, ) - return final_output -def create_pipeline(config: Config, layers: nn.Module, mesh: Mesh, remat_policy: Any = None) -> PipelineBase: +def create_pipeline( + config: Config, layers: nn.Module, mesh: Mesh, remat_policy: Any = None +) -> Pipeline | CircularPipeline: """Factory function to instantiate the correct Pipeline module based on config.""" - if config.pipeline_fsdp_ag_per_repeat: return CircularPipeline(config=config, layers=layers, mesh=mesh, remat_policy=remat_policy) - return Pipeline(config=config, layers=layers, mesh=mesh, remat_policy=remat_policy) + + +class NNXPipelineBase(nnx.Module, PipelineSharedMixin): + """Base module that implements shared pipelining logic across stages for NNX.""" + + def __init__( + self, + config: Config, + stage_factory: Any, + mesh: Mesh, + remat_policy: Any = None, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.remat_policy = remat_policy + self._setup_pipeline_attributes() + + def build_batched_rngs(shape): + max_logging.log(f"Building batched RNGs with shape {shape}...") + kwargs = {} + rng_state = nnx.state(rngs, nnx.RngState) + leaves, _ = jax.tree_util.tree_flatten_with_path(rng_state) + for path, key in leaves: + stream_name = getattr(path[0], "key", str(path[0])) + if not jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): + key = jax.random.key(key) + num_splits = int(np.prod(shape)) + flat_keys = jax.random.split(key, num_splits) + kwargs[stream_name] = flat_keys.reshape(shape + key.shape) + print(f"DEBUG: build_batched_rngs created kwargs keys: {kwargs.keys()}") + return nnx.Rngs(**kwargs) + + def create_stage_fn(r): + stage = stage_factory(r) + # Split into (GraphDef, Param State, Rest of State) + return nnx.split(stage, nnx.Param, ...) 
+ + vmap_stages = nnx.vmap( + create_stage_fn, + in_axes=0, + out_axes=(None, 0, 0), + axis_name=self.spmd_axis_name, + transform_metadata={nnx.PARTITION_NAME: "layers"}, + ) + + if self.config.num_pipeline_repeats > 1: + vmap_repeats = nnx.vmap( + vmap_stages, + in_axes=0, + out_axes=(None, 0, 0), + transform_metadata={nnx.PARTITION_NAME: "circular_repeats"}, + ) + batched_rngs = build_batched_rngs((self.config.num_pipeline_repeats, self.num_stages)) + graphdef, params, rest = vmap_repeats(batched_rngs) + else: + batched_rngs = build_batched_rngs((self.num_stages,)) + graphdef, params, rest = vmap_stages(batched_rngs) + + # Merge the batched states back into the module + self.layers = nnx.merge(graphdef, params, rest) + + def get_weight_sharding(self, *init_args): + """get weight sharding function for this pipeline.""" + state = nnx.state(self.layers) + + def get_spec(x): + if not isinstance(x, nnx.VariableState): + return None + if isinstance(x.value, nn.spmd.LogicallyPartitioned): + return x.value.partitions + metadata = x.get_metadata() + sharding = metadata.get("sharding") + if sharding and hasattr(sharding, "spec"): + return sharding.spec + return None + + return jax.tree.map(get_spec, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + + def get_main_vmap_func_for_iterations(self): + def func_to_vmap(graph, state, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + module = nnx.merge(graph, state) + out = module(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + return out, nnx.state(module) + + return nnx.vmap( + func_to_vmap, + in_axes=(None, 0, 0, 0, 0, None, None), + out_axes=(0, 0), + axis_name=self.spmd_axis_name, + ) + + +class NNXPipeline(NNXPipelineBase): + """Original Pipeline implementation adapted for NNX.""" + + def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None): + if self.config.num_pipeline_repeats > 1: + return 
self.get_current_repeat_from_stages( + pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec + ) + return pipeline_weights + + def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): + """Fetches the weights for the current repeat from the stages.""" + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def _gather_repeat(w_tree, rep_id): + # w_tree: each leaf has (repeats, ...) shape — stages dim removed by outer vmap + # rep_id: scalar — this stage's repeat index + def _gather_leaf(w): + if w is None: + return None + sliced = jax.lax.dynamic_slice_in_dim(w, rep_id, 1, axis=0) + return jnp.squeeze(sliced, axis=0) + + return jax.tree.map(_gather_leaf, w_tree) + + return jax.vmap(_gather_repeat, in_axes=(1, 0), out_axes=0)(weights, repeat_ids) + + def run_one_iteration( + self, + loop_state, + pipeline_weights_graph, + pipeline_weights_state, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec=None, + ): + """Executes the logic for a single microbatch iteration, including routing inputs and weights, and advancing buffers.""" + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + + vmap_func = self.get_main_vmap_func_for_iterations() + + stage_weights_state = 
self.get_current_stage_weights( + pipeline_weights_state, loop_iteration, physical_partition_spec=physical_partition_spec + ) + + stages_output, updated_stage_weights_state = vmap_func( + pipeline_weights_graph, + stage_weights_state, + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, + ) + + if self.config.scan_layers: + stages_output = stages_output[0] + + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def _tree_scatter_update(fw_tree, uw_tree): + r_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) + + def _update_one_stage(f_s_tree, u_s_tree, r_id): + def _update_leaf(f_s, u_s): + if f_s is None or u_s is None: + return f_s + return jax.lax.dynamic_update_slice_in_dim(f_s, jnp.expand_dims(u_s, 0), r_id, axis=0) + + return jax.tree.map(_update_leaf, f_s_tree, u_s_tree) + + updated_fw_tree = jax.vmap(_update_one_stage, in_axes=(1, 0, 0), out_axes=1)(fw_tree, uw_tree, r_ids) + return jax.tree.map( + lambda x: self.shard_dim_by_stages(x, 1, physical_partition_spec=None, is_stage_weight=False), updated_fw_tree + ) + + pipeline_weights_state = _tree_scatter_update(pipeline_weights_state, updated_stage_weights_state) + else: + pipeline_weights_state = updated_stage_weights_state + + new_state = self.advance_circular_buffers(stages_output, loop_state) + return new_state, pipeline_weights_state + + def __call__( + self, + inputs: jnp.ndarray, + segment_ids: jnp.ndarray, + positions: jnp.ndarray, + deterministic: bool, + model_mode=MODEL_MODE_TRAIN, + logical_partition_spec=None, + ) -> jnp.ndarray: + inputs = inputs.reshape( + ( + self.config.num_pipeline_microbatches, + self.pipeline_microbatch_size, + self.config.max_target_length, + self.config.emb_dim, + ), + out_sharding=self.input_sharding, + ) + ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None)) + if positions is not None: + positions = 
self._maybe_shard_with_name(positions, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + if segment_ids is not None: + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + + loop_state = self.init_states(inputs) + + bubble_iterations = self.forwarding_delay * (self.num_stages - 1) + real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats + total_iterations = real_iterations + bubble_iterations + + logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec) + + layers_graph, layers_state = nnx.split(self.layers) + + def is_lp(x): + return isinstance(x, nn.spmd.LogicallyPartitioned) + + def unbox_val(x): + return x.value if is_lp(x) else x + + layers_state = jax.tree.map(unbox_val, layers_state, is_leaf=is_lp) + + if self.config.pipeline_fsdp_ag_once: + layers_state = self.all_gather_over_fsdp(layers_state, logical_partition_spec) + + def is_static_param(path, v): + return isinstance(v, nnx.Param) or type(v).__name__ == "_overwrite_with_gradient" + + _, layers_params, layers_metrics, layers_mutables = nnx.split(layers_state, is_static_param, nnx.Intermediate, ...) + print(f"DEBUG: layers_mutables keys before scan are {layers_mutables.keys()}") + + def scan_body(carry, _): + current_loop_state, current_layer_mutables = carry + current_layer_state = nnx.State.merge(layers_params, layers_metrics, current_layer_mutables) + + new_loop_state, new_layer_state = self.run_one_iteration( + current_loop_state, + layers_graph, + current_layer_state, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec, + ) + + _, _, new_layer_metrics, new_layer_mutables = nnx.split(new_layer_state, is_static_param, nnx.Intermediate, ...) 
+ return (new_loop_state, new_layer_mutables), new_layer_metrics + + if self.config.set_remat_policy_on_pipeline_iterations: + scan_body = jax.checkpoint( + scan_body, policy=self.get_pipeline_remat_policy(), prevent_cse=not self.config.scan_pipeline_iterations + ) + + if self.config.scan_pipeline_iterations: + (loop_state, final_layer_mutables), stacked_metrics = jax.lax.scan( + scan_body, (loop_state, layers_mutables), None, length=total_iterations + ) + else: + current_carry = (loop_state, layers_mutables) + metrics_history = [] + for _ in range(total_iterations): + current_carry, step_metrics = scan_body(current_carry, None) + metrics_history.append(step_metrics) + loop_state, final_layer_mutables = current_carry + stacked_metrics = jax.tree.map(lambda *xs: jnp.stack(xs), *metrics_history) if metrics_history else layers_metrics + + final_layer_state = nnx.State.merge(layers_params, stacked_metrics, final_layer_mutables) + self.layers = nnx.merge(layers_graph, final_layer_state, copy=True) + + final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) + return jnp.reshape( + final_output, + (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), + out_sharding=self.output_sharding, + ) + + +class NNXCircularPipeline(NNXPipelineBase): + """NNX Implementation of a circular pipeline schedule with asynchronous weight prefetching. + + Inherits directly from NNXPipelineBase to leverage its native nnx.vmap setup for + pipeline variables (stages and circular_repeats). Uses pure jax.lax.scan and + jax.checkpoint for execution to avoid banned NNX execution transforms. + """ + + def init_states(self, inputs): + """Initializes pipeline execution state and Empty BSW buffers.""" + loop_state = super().init_states(inputs) + + # Only prefetch parameters to avoid PyTree mismatches with dynamically created rngs/mutables. 
+ weights = nnx.state(self.layers, nnx.Param) + + def get_single_repeat_shape(x): + if x is None: + return None + if self.config.num_pipeline_repeats > 1: + if isinstance(x, jax.ShapeDtypeStruct): + # Manually calculate sliced shape for abstract tracers + new_shape = (1,) + x.shape[1:] + return jax.ShapeDtypeStruct(new_shape, x.dtype) + return jnp.zeros_like(x[0]) + return jnp.zeros_like(x) + + bsw = ( + jax.tree.map(get_single_repeat_shape, weights), + jax.tree.map(get_single_repeat_shape, weights), + ) + + return loop_state, bsw + + def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim): + """Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch.""" + if xs is None: + return None + + xs = jnp.asarray(xs) + ndim = xs.ndim + + def _gather_one(x, i): + idx = tuple(i if d == ids_dim else slice(None) for d in range(ndim)) + replicated_sharding = NamedSharding(self.mesh, P()) + return x.at[idx].get(out_sharding=replicated_sharding) + + return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) + + def gather_weights_across_stages_vmap(self, weights_state, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): + """Uses jax.vmap to dynamically slice and gather weights for specific pipeline repeats.""" + + def _gather_repeat(w_tree, rep_id): + def _slice_leaf(w): + if w is None: + return None + sliced = jax.lax.dynamic_slice_in_dim(w, rep_id, 1, axis=repeat_dim_in_weights) + return jnp.squeeze(sliced, axis=repeat_dim_in_weights) + + return jax.tree.map(_slice_leaf, w_tree) + + return jax.vmap(_gather_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=0)(weights_state, repeat_ids) + + def from_all_variables_to_repeat_weights(self, weights_state, loop_iteration): + """Slices out the specific repeat's weights from the full weights state.""" + if self.config.num_pipeline_repeats == 1: + return weights_state + + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + return 
self.gather_weights_across_stages_vmap( + weights_state, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 + ) + + def from_repeat_weights_to_bsw(self, repeat_weights, physical_partition_spec): + """Executes FSDP-like all-gathers to fully materialize a block of weights for BSW.""" + axes_to_remove = ["fsdp", "fsdp_transpose"] + if physical_partition_spec is not None: + bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove) + else: + bsw_pps = None + + def _apply_sharding_hint(weight, pspec): + if pspec is None or weight is None: + return weight + sharding_name = NamedSharding(self.mesh, pspec) + return maybe_shard_with_name( + weight, + sharding_name, + shard_mode=self.config.shard_mode, + debug_sharding=self.config.debug_sharding, + extra_stack_level=0, + ) + + if bsw_pps is None: + return repeat_weights + + # Flatten specs to a list aligned with repeat_weights' leaf traversal order. + # Single-tree map avoids nnx.Variable mutation (TraceContextError inside scan). 
+ def is_spec_leaf(x): + return isinstance(x, P) or x is None + + spec_leaves = jax.tree_util.tree_leaves(bsw_pps, is_leaf=is_spec_leaf) + spec_iter = iter(spec_leaves) + return jax.tree.map(lambda w: _apply_sharding_hint(w, next(spec_iter)), repeat_weights) + + def weight_prefetching(self, weights_state, physical_partition_spec, loop_iteration): + """Triggers asynchronous FSDP-like all-gathers for current and next pipeline steps.""" + cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights_state, loop_iteration) + nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights_state, loop_iteration + 1) + bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec) + bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec) + return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw") + + def fetch_active_stage_weights(self, bsw, loop_iteration, physical_partition_spec=None): + """The module fetches the actively prefetched weights + from the Buffer Sliding Window to avoid mid-iteration FSDP all-gathers. + """ + return self.get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec) + + def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec): + """Pulls the fully gathered parameters for the current repeat from the BSW dual-buffer.""" + + # Strip nnx.Variable wrappers from pps to match the structure of bsw (which is raw arrays) + def _is_var(x): + return isinstance(x, nnx.Variable) + + physical_partition_spec = jax.tree.map(lambda x: x.value, physical_partition_spec, is_leaf=_is_var) + + bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) + # Ensure bsw_pps matches the None-structure of bsw. + # A leaf should only be None in pps if it's None in BOTH buffers. 
+ bsw_dict_0 = nnx.to_pure_dict(bsw[0]) + bsw_dict_1 = nnx.to_pure_dict(bsw[1]) + pps_dict = nnx.to_pure_dict(bsw_pps) + + def is_really_none(x): + return x is None + + bsw_pps = jax.tree.map( + lambda b0, b1, p: p if (not is_really_none(b0) or not is_really_none(b1)) else None, + bsw_dict_0, + bsw_dict_1, + pps_dict, + ) + + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + stage0_repeat_id = jnp.maximum(loop_iteration, 0) // self.config.num_pipeline_microbatches + + # Only use shard_map when there are actual FSDP-sharded params (non-None specs). + # If all specs are None (no FSDP), fall through to the vmap path. + pps_leaves_for_check = jax.tree_util.tree_leaves(bsw_pps, is_leaf=lambda x: isinstance(x, (P, type(None)))) + has_fsdp_params = any(leaf is not None for leaf in pps_leaves_for_check) + + if bsw_pps is not None and has_fsdp_params: + + @jax.shard_map(mesh=self.mesh, in_specs=((bsw_pps, bsw_pps), P("stage")), out_specs=bsw_pps, check_vma=False) + def select_weights_from_bsw(bsw_inner, repeat_id): + def _select_leaf(x, y, p): + if p is None: + return None + return jax.lax.select(repeat_id[0] == stage0_repeat_id, y, x) + + return jax.tree.map(_select_leaf, bsw_inner[0], bsw_inner[1], bsw_pps) + + # Convert bsw data to dicts to match bsw_pps structure for shard_map. 
+ bsw_data = (bsw_dict_0, bsw_dict_1) + weights = select_weights_from_bsw(bsw_data, repeat_ids) + else: + + def select_weights_from_bsw(bsw_inner, repeat_id): + def _select_leaf(x, y): + if x is None: + return None + return jax.lax.select(repeat_id == stage0_repeat_id, y, x) + + return jax.tree.map(_select_leaf, bsw_inner[0], bsw_inner[1]) + + weights = jax.vmap(select_weights_from_bsw, in_axes=((0, 0), 0), out_axes=0)(bsw, repeat_ids) + return weights + + def run_one_iteration( + self, + loop_state, + bsw, + pipeline_weights_graph, + layers_params, + layers_metrics, + current_layer_mutables, + positions, + segment_ids, + deterministic, + model_mode, + params_physical_partition_spec, + ): + """Executes the forward/backward logic for a single microbatch inside the circular pipeline.""" + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + + stages_positions = self.gather_microbatch_inputs_vmap(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = ( + self.gather_microbatch_inputs_vmap(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + ) + + vmap_func = self.get_main_vmap_func_for_iterations() + + # 1. Fetch prefetched parameters from BSW + stage_params = self.fetch_active_stage_weights( + bsw, + loop_iteration, + physical_partition_spec=params_physical_partition_spec, + ) + + # 2. 
Fetch mutables/RNGs directly from the current layer state components + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + stage_metrics_fetched = self.gather_weights_across_stages_vmap( + layers_metrics, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 + ) + stage_mutables_fetched = self.gather_weights_across_stages_vmap( + current_layer_mutables, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 + ) + + # 3. Merge them into the complete state required for the forward pass + stage_weights_state = nnx.State.merge(stage_params, stage_metrics_fetched, stage_mutables_fetched) + + stages_output, updated_stage_weights_state = vmap_func( + pipeline_weights_graph, + stage_weights_state, + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, + ) + + if self.config.scan_layers: + stages_output = stages_output[0] + + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def _tree_scatter_update(fw_tree, uw_tree): + r_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) + + def _update_one_stage(f_s_tree, u_s_tree, r_id): + def _update_leaf(f_s, u_s): + if f_s is None or u_s is None: + return f_s + return jax.lax.dynamic_update_slice_in_dim(f_s, jnp.expand_dims(u_s, 0), r_id, axis=0) + + return jax.tree.map(_update_leaf, f_s_tree, u_s_tree) + + updated_fw_tree = jax.vmap(_update_one_stage, in_axes=(1, 0, 0), out_axes=1)(fw_tree, uw_tree, r_ids) + return jax.tree.map( + lambda x: self.shard_dim_by_stages(x, 1, physical_partition_spec=None, is_stage_weight=False), updated_fw_tree + ) + + def is_static_param(path, v): + return isinstance(v, nnx.Param) or type(v).__name__ == "_overwrite_with_gradient" + + # We only need to update the metrics and mutables in the carry, as parameters are handled by AD + # We extract only the non-static parts from the updated stage state + _, _, updated_stage_metrics, 
updated_stage_mutables = nnx.split( + updated_stage_weights_state, is_static_param, nnx.Intermediate, ... + ) + updated_stage_non_params = nnx.State.merge(updated_stage_metrics, updated_stage_mutables) + + current_layer_state = nnx.State.merge(layers_metrics, current_layer_mutables) + new_layer_state = _tree_scatter_update(current_layer_state, updated_stage_non_params) + else: + # If not repeats, we just return the full state (params will be ignored in carry) + new_layer_state = updated_stage_weights_state + + new_state = self.advance_circular_buffers(stages_output, loop_state) + return new_state, new_layer_state + + def __call__( + self, + inputs: jnp.ndarray, + segment_ids: jnp.ndarray, + positions: jnp.ndarray, + deterministic: bool, + model_mode=MODEL_MODE_TRAIN, + logical_partition_spec=None, + ) -> jnp.ndarray: + inputs = inputs.reshape( + ( + self.config.num_pipeline_microbatches, + self.pipeline_microbatch_size, + self.config.max_target_length, + self.config.emb_dim, + ), + out_sharding=self.input_sharding, + ) + + ag_sharding = NamedSharding(self.mesh, P(None, None)) + if positions is not None: + positions = self._maybe_shard_with_name(positions, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + if segment_ids is not None: + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + + loop_state, _ = self.init_states(inputs) + + physical_partition_spec = logical_to_mesh( + logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules + ) + + bubble_iterations = self.forwarding_delay * (self.num_stages - 1) + real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats + total_iterations = real_iterations + bubble_iterations + + layers_graph, layers_state = nnx.split(self.layers) + + def is_lp(x): + return 
isinstance(x, nn.spmd.LogicallyPartitioned) + + def unbox_val(x): + return x.value if is_lp(x) else x + + layers_state = jax.tree.map(unbox_val, layers_state, is_leaf=is_lp) + + def is_static_param(path, v): + return isinstance(v, nnx.Param) or type(v).__name__ == "_overwrite_with_gradient" + + _, layers_params, layers_metrics, layers_mutables = nnx.split(layers_state, is_static_param, nnx.Intermediate, ...) + + # Filter physical_partition_spec to only contain keys that exist in layers_params. + # This prevents structural mismatches in shard_map when layers_state has more keys (like dropout). + def filter_to_match(path, _): + try: + spec = physical_partition_spec + for p in path: + spec = spec[p.key if hasattr(p, "key") else p] + return spec + except (KeyError, TypeError, AttributeError): + return None + + params_physical_partition_spec = jax.tree_util.tree_map_with_path(filter_to_match, layers_params) + + def scan_body(carry, _): + current_loop_state, current_layer_mutables = carry + + # 1. Async FSDP Prefetch into Buffer Sliding Window. Only operate on parameters. + # We compute next_bsw locally; it's not carried to save memory. + next_bsw = self.weight_prefetching( + layers_params, params_physical_partition_spec, current_loop_state["loop_iteration"] + ) + + # 2. Run Forward & State Shift + new_loop_state, new_layer_mutables = self.run_one_iteration( + current_loop_state, + next_bsw, + layers_graph, + layers_params, + layers_metrics, + current_layer_mutables, + positions, + segment_ids, + deterministic, + model_mode, + params_physical_partition_spec, + ) + + # 3. Extract metrics (which are returned separately in scan) + # Since run_one_iteration now returns (new_state, new_mutables), + # we need to extract metrics from new_mutables if they were added. + # However, metrics are usually added via self.sow(nnx.Intermediate, ...) + # which results in nnx.Intermediate nodes in the state. + # To keep it simple, we split them here. 
+ _, _, new_layer_metrics, new_layer_mutables = nnx.split(new_layer_mutables, is_static_param, nnx.Intermediate, ...) + return (new_loop_state, new_layer_mutables), new_layer_metrics + + if self.config.set_remat_policy_on_pipeline_iterations: + scan_body = jax.checkpoint( + scan_body, policy=self.get_pipeline_remat_policy(), prevent_cse=not self.config.scan_pipeline_iterations + ) + + # Memory Efficient Execution via pure JAX scan + if self.config.scan_pipeline_iterations: + (loop_state, final_layer_mutables), stacked_metrics = jax.lax.scan( + scan_body, (loop_state, layers_mutables), None, length=total_iterations + ) + else: + current_carry = (loop_state, layers_mutables) + metrics_history = [] + for _ in range(total_iterations): + current_carry, step_metrics = scan_body(current_carry, None) + metrics_history.append(step_metrics) + loop_state, final_layer_mutables = current_carry + stacked_metrics = jax.tree.map(lambda *xs: jnp.stack(xs), *metrics_history) if metrics_history else layers_metrics + + final_layer_state = nnx.State.merge(layers_params, stacked_metrics, final_layer_mutables) + + # Skip direct mutation of 'layers' during compilation to avoid TraceContextError. 
+ is_tracing = any(isinstance(x, jax.core.Tracer) for x in jax.tree_util.tree_leaves(final_layer_state)) + if not is_tracing: + nnx.update(self.layers, final_layer_state) + + final_output = self.realign_output_microbatches(loop_state["state_io"]) + return jnp.reshape( + final_output, + (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), + out_sharding=self.output_sharding, + ) + + +def create_nnx_pipeline( + config: Config, stage_factory: Any, mesh: Mesh, remat_policy: Any = None, *, rngs: nnx.Rngs +) -> NNXPipeline | NNXCircularPipeline: + """Factory function to instantiate the NNX Pipeline module.""" + if config.pipeline_fsdp_ag_per_repeat: + return NNXCircularPipeline( + config=config, stage_factory=stage_factory, mesh=mesh, remat_policy=remat_policy, rngs=rngs + ) + return NNXPipeline(config=config, stage_factory=stage_factory, mesh=mesh, remat_policy=remat_policy, rngs=rngs) diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py new file mode 100644 index 0000000000..9ef0e6dffd --- /dev/null +++ b/src/maxtext/layers/train_state_nnx.py @@ -0,0 +1,48 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" The NNX Unified TrainState. """ + +from typing import Any + +from flax import nnx + + +class TrainStateNNX(nnx.Module): + """ + A unified container for NNX models and optimizers. + This replaces Linen's TrainState for checkpointing. 
+ + Linen TrainState pytree: + {"params": {...}, "opt_state": {...}} + TrainStateNNX state pytree: + {"model": {...}, "optimizer": {"opt_state": {...}}} + """ + + def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + self.model = model + self.optimizer = optimizer + + def apply_gradients(self, grads: Any): + """ + Mimics the Linen apply_gradients function. + Updates the optimizer state, applies updates to parameters, + and increments the step counter. + """ + if self.optimizer is None: + raise RuntimeError( + "Cannot call apply_gradients on a TrainStateNNX initialized without an optimizer. " + "This usually happens when the state was created for inference only." + ) + self.optimizer.update(self.model, grads) diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 58a0a2db8f..68445a87d5 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -28,6 +28,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations @@ -130,6 +131,8 @@ def __init__( rngs=rngs, ) + + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + def __call__( self, inputs, @@ -181,7 +184,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index c28020d781..9a93e66a1c 100644 --- a/src/maxtext/models/olmo3.py +++ 
b/src/maxtext/models/olmo3.py @@ -29,6 +29,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -139,6 +140,7 @@ def __init__( model_mode=model_mode, rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) def __call__( self, @@ -193,7 +195,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index eb15747fc2..5ba630adc3 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -962,7 +962,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -987,7 +987,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. 
self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index a9ef64631a..39d84a89dc 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -26,6 +26,7 @@ from typing import Any, Callable import drjax +from flax import nnx from flax import struct from flax.training import train_state import jax @@ -153,7 +154,15 @@ def add_diloco_dim(x): momentum=config.diloco_outer_momentum, nesterov=True, ) - outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params) + # For NNX, model params (Param variables only) live under abstract_state.model; + # for Linen under abstract_state.params. + if config.pure_nnx: + model_params = abstract_state.model.filter(nnx.Param) + model_params_sharding = state_mesh_shardings.model.filter(nnx.Param) + else: + model_params = abstract_state.params + model_params_sharding = state_mesh_shardings.params + outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params) # Create abstract step abstract_step = jax.ShapeDtypeStruct((), jnp.int32) @@ -161,7 +170,7 @@ def add_diloco_dim(x): # Build abstract DiLoCo state diloco_state = DiLoCoTrainState( inner_state=inner_state, - params=abstract_state.params, + params=model_params, outer_opt_state=outer_opt_state, step=abstract_step, ) @@ -171,12 +180,12 @@ def add_diloco_dim(x): # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) # We shard the momentum trace the same way as the parameters. 
outer_opt_state_sharding = ( - optax.TraceState(trace=state_mesh_shardings.params), + optax.TraceState(trace=model_params_sharding), optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, - params=state_mesh_shardings.params, + params=model_params_sharding, outer_opt_state=outer_opt_state_sharding, step=None, ) @@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: # mesh automatically when jax.set_mesh is used. inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. - outer_params = state.params + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. + outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params outer_opt_state = outer_optimizer.init(outer_params) outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step), + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), outer_opt_state_sharding, ) @@ -244,7 +257,11 @@ def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) - model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. 
+ inner_model_params = ( + nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params + ) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. averaged_pseudo_grad = drjax.reduce_mean(model_delta) @@ -253,7 +270,27 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) + if config.pure_nnx: + # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state). + def replace_nnx_model_params(s, new_params): + non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param)) + new_model = nnx.merge_state(non_param_model, new_params) + # Build result via __setitem__ so nested States are stored as plain dicts + # internally, matching the pytree structure produced by nnx.state(). + # (Passing State objects via the constructor dict literal stores them + # as-is, causing jax.lax.cond to see mismatched pytree structures.) 
+ result = type(s)({}) + result["model"] = new_model + result["optimizer"] = s["optimizer"] + return result + + new_inner_state = drjax.map_fn( + lambda s: replace_nnx_model_params(s, new_outer_params), + state.inner_state, + mesh=mesh, + ) + else: + new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng): broadcast_rng = drjax.broadcast(prng, mesh=mesh) inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) + # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step. + new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0] state = state.replace( inner_state=inner_state, - step=inner_state.step[0], + step=new_step, ) # Either synchronize the model, or no-op, depending on whether the current # step falls on the synchronization period. state = jax.lax.cond( - inner_state.step[0] % config.diloco_sync_period == 0, + new_step % config.diloco_sync_period == 0, synchronize, lambda x: x, # no-op state, diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index ff8cdde8a8..2c970d313a 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -557,7 +557,13 @@ def map_to_pspec(data): ) } - if optimizer is not None: + # Only restore optimizer state if it was actually saved in this checkpoint. + # PeftTrainer.save() doesn't pass the optimizer, so older checkpoints may + # only contain model_params. 
+ ckpt_optimizer_path = self._checkpoint_manager.directory / str(step) / "optimizer_state" + checkpoint_has_optimizer = optimizer is not None and ckpt_optimizer_path.exists() + + if checkpoint_has_optimizer: optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState) opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state) cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore( @@ -571,7 +577,7 @@ def map_to_pspec(data): ) nnx.update(target_model, restored.model_params) - if optimizer is not None: + if checkpoint_has_optimizer: nnx.update(optimizer, restored.optimizer_state) metadata = self._checkpoint_manager.metadata(step) diff --git a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py index 7cc8f5b658..c7f6bd4740 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py +++ b/src/maxtext/trainers/post_train/sft/train_sft_deprecated.py @@ -85,7 +85,7 @@ def train_loop(config, recorder, state=None): compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) data_loader = DataLoader(config, mesh, data_iterator, recorder) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 2c374ba651..1bde773c62 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -34,8 +34,9 @@ import jax import jax.numpy as jnp +from jax.sharding import NamedSharding -from flax import linen as nn +from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig @@ -67,6 +68,7 @@ from maxtext.utils 
import maxtext_utils from maxtext.utils import qk_clip_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss @@ -76,8 +78,10 @@ VertexTensorboardManager, _vertex_tb_is_stub = vertex_tensorboard_modules() -def get_first_step(state): - return int(state.step) +def get_first_step(model, state): + if isinstance(model, nn.Module): + return int(state.step) + return int(state.optimizer.step.get_value()) # ----------------------------------------------------------------------------- @@ -89,11 +93,11 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): """loss_fn for both train and eval. Args: - model: A nn.Module + model: A nn.Module (Linen) or nnx.Module (NNX). config: Config of parameters data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout - params: Model params + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. + params: Model params (Linen); unused for NNX (params are part of the model). is_train: True for train_step and False for eval_step Returns: @@ -172,7 +176,7 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): total_loss = jnp.sum(xent) total_z_loss = jnp.sum(z_loss) else: - # Flax NNX model + # Flax NNX model: logits = model( decoder_input_tokens=data["inputs"], decoder_positions=data["inputs_position"], @@ -183,7 +187,11 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): decoder_target_tokens=data["targets"], decoder_target_mask=data["targets_segmentation"], ) - intermediate_outputs = {} + # Capture NNX intermediates (MoE losses, hidden states, etc.) 
+ intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + + if config.num_vocab_tiling > 1: + raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") if (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. @@ -295,62 +303,98 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True): return loss, aux -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """ +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng=None): + """Training step for both Linen and NNX models. Args: - model: A nn.Module - state: A pytree of the current state of the model - data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout + model: A nn.Module (Linen) or nnx.GraphDef of the TrainStateNNX (NNX). + config: Hyperparameters. + state_mesh_shardings: PyTree of PartitionSpecs for the train state. + params_shardings: PyTree of PartitionSpecs for model parameters, used for gradient accumulation. + state: Linen TrainState or NNX pure State. + data: Training data batch. + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. Returns: - new_state: Same format as state. + new_state: Updated Linen TrainState or NNX pure State. metrics: Dictionary of model metrics such as loss, training rate, etc. - rng2: A new rng key that can be used in future calls. 
- """ - reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = ( - [], - [], - [], - loss_fn, - ) - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - params = state.params + # --- Per-path initialization --- + if isinstance(model, nn.Module): + reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = [], [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + params = state.params + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args + else: + if config.use_dpo: + raise NotImplementedError("DPO for NNX modules has not been implemented.") + state = nnx.merge(model, state) # reconstruct TrainStateNNX + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] + # --- Gradient computation --- if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( - _loss_fn, + ga_fn, config, - model, - params, + ga_model, + ga_params, params_shardings, data, - dropout_rng, - extra_dpo_args, + ga_rng, + ga_dpo, ) else: - if config.optimizer_memory_host_offload: - if config.use_dpo: + if isinstance(model, nn.Module): + if config.optimizer_memory_host_offload and config.use_dpo: reference_params = jax.device_put( reference_params, max_utils.with_memory_kind(reference_params_sharding, "device"), ) extra_dpo_args = [reference_params] - if config.shard_optimizer_over_data: - params = jax.tree.map( - functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), - params, - params_shardings, - ) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - 
(loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + if config.shard_optimizer_over_data: + params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + params, + params_shardings, + ) + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + else: + model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + if config.parameter_memory_host_offload: + # Params are kept on host (pinned_host) in in_shardings. Move only Param + # variables to device before the forward/backward pass so that all dot_general + # operands share the same memory space (XLA on GPU requires this). + # Using params_shardings (Param-only) avoids Shardy rank mismatches that + # occur when applying PartitionSpec() (rank-0 in SDY) to rank-1 RNG key tensors. + device_param_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + params_shardings, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + curr_params = jax.device_put(curr_params, device_param_shardings) + nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update + if config.shard_optimizer_over_data: + curr_params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + curr_params, + params_shardings, + ) + nnx.update(state.model, curr_params) + + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(model_graphdef, param, rest, copy=True) + loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) 
+ return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -361,6 +405,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat raw_grads, max_utils.with_memory_kind(params_shardings, "device"), ) + + # Extract aux fields into locals intermediate_outputs = aux["intermediate_outputs"] total_weights = aux["total_weights"] moe_lb_loss = aux["moe_lb_loss"] @@ -369,43 +415,65 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat moe_bias_updates = aux["moe_bias_updates"] mtp_loss = aux["mtp_loss"] - if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + if isinstance(model, nn.Module): + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + state = state.replace( + opt_state=jax.device_put( + state.opt_state, + jax.tree_util.tree_map( + lambda x: x.with_memory_kind(kind="device"), + state_mesh_shardings.opt_state, + ), + ) + ) + # Move all parameters to device before optimizer update + if config.parameter_memory_host_offload: + max_logging.log("\nMoving all parameters to device before optimizer update") + + def move(path, value): + max_logging.log(f"train.py: Moving f{path} to device") + return value.with_memory_kind(kind="device") + + state = state.replace( + params=jax.device_put( + state.params, + jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), + ) + ) + new_state = state.apply_gradients(grads=grads) + + # Apply updates for Auxiliary-Loss-Free load balancing for 
DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") + # Updates the shape to be aligned with state. + moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() + new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) else: grads = raw_grads - if config.optimizer_memory_host_offload: - state = state.replace( - opt_state=jax.device_put( - state.opt_state, - jax.tree_util.tree_map( - lambda x: x.with_memory_kind(kind="device"), - state_mesh_shardings.opt_state, - ), - ) - ) - # Move all parameters to device before optimizer update - if config.parameter_memory_host_offload: - max_logging.log("\nMoving all parameters to device before optimizer update") - - def move(path, value): - max_logging.log(f"train.py: Moving f{path} to device") - return value.with_memory_kind(kind="device") - - state = state.replace( - params=jax.device_put( - state.params, - jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), - ) - ) - new_state = state.apply_gradients(grads=grads) + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + if config.optimizer_memory_host_offload: + # state.optimizer is an NNX Optimizer module; state_mesh_shardings.optimizer + # is an NNX State. Use nnx.state() to get a compatible State for device_put. 
+ device_opt_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + opt_state = nnx.state(state.optimizer) + new_opt_state = jax.device_put(opt_state, device_opt_shardings) + nnx.update(state.optimizer, new_opt_state) + state.apply_gradients(grads) + new_state = state - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Flax 'sow' returns a tuple, so we take the first element [0]. - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias + target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() scalar_metrics = { "learning/loss": loss, @@ -416,8 +484,9 @@ def move(path, value): "learning/total_weights": total_weights, } if config.use_qk_clip: - # Apply QK-Clip - new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + # Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX) + if isinstance(model, nn.Module): + new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) # Report max_logits metric global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) @@ -427,34 +496,41 @@ def move(path, value): if not config.optimizer_memory_host_offload: scalar_metrics["learning/grad_norm"] = 
max_utils.l2norm_pytree(grads) scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + if isinstance(model, nn.Module): + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + else: + _, model_params, _ = nnx.split(new_state.model, nnx.Param, ...) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(model_params) if config.use_dpo: scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] metrics = { "scalar": scalar_metrics, "scalars": {}, } - if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) - if config.use_dpo: - new_state = _merge_dpo_state(new_state, reference_params) - - return new_state, metrics + if isinstance(model, nn.Module): + if config.use_dpo: + new_state = _merge_dpo_state(new_state, reference_params) + return new_state, metrics + return nnx.state(new_state), metrics -def eval_step(model, config, state, data, dropout_rng): +def eval_step(model, config, state, data, dropout_rng=None): """eval_step no backprop and new state compared with train_step.""" - - reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn(state.params, *extra_dpo_args) + if isinstance(model, nn.Module): + reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + + eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(state.params, *extra_dpo_args) + else: + state = nnx.merge(model, state) # reconstruct 
TrainStateNNX + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -478,7 +554,7 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - if config.use_dpo: + if isinstance(model, nn.Module) and config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -500,41 +576,59 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if isinstance(model, nn.Module): + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + jit_model = model + else: + if config.use_dpo: + raise NotImplementedError("DPO is not supported for NNX models.") + jit_model, state = nnx.split(state) params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + jit_model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) shaped_batch = 
maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng)) + elif config.shard_optimizer_over_data: + # NNX: reshard state so params match the data-sharded in_shardings (Zero-1 layout) + state = jax.device_put(state, state_mesh_shardings) + if isinstance(model, nn.Module): + lower_args = (state, shaped_batch, init_rng) + else: + lower_args = (state, shaped_batch) + maxtext_utils.maybe_dump_jaxpr(config, p_train_step, lower_args) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled = p_train_step.lower(*lower_args).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training prof = profiler.Profiler(config, offset_step=start_step) metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params) + if isinstance(model, nn.Module): + setup_params = state.params + else: + _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) 
+ metric_logger.write_setup_info_to_tensorboard(setup_params) _job_completed_gracefully = False try: @@ -544,57 +638,60 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + if isinstance(model, nn.Module): + # pylint: disable=not-callable + step_rng_args = (jax.jit(jax.random.fold_in)(init_rng, step),) + else: + step_rng_args = () with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - state, metrics = p_train_step(state, example_batch, nextrng) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. 
- gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - # Explicitly reset the eval iterator and counters before starting the eval loop - eval_data_iterator.reset() - metric_logger.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) - if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + state, metrics = p_train_step(state, example_batch, *step_rng_args) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. 
+ gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) + metric_logger.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 78392a388a..abf9817ab4 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -27,8 +27,10 @@ from typing import Sequence from absl import app +from flax import nnx from flax.linen import partitioning as nn_partitioning import jax 
+import jax.numpy as jnp from jax.experimental.serialize_executable import serialize from jax.experimental.topologies import get_topology_desc from jax.sharding import AxisType, Mesh @@ -36,6 +38,7 @@ from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode from maxtext.layers import quantizations +from maxtext.layers import train_state_nnx from maxtext.models import models from maxtext.optimizers import optimizers from maxtext.trainers.diloco import diloco @@ -44,6 +47,8 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils # pylint: disable=too-many-positional-arguments @@ -89,11 +94,35 @@ def get_topology_mesh(config): return topology_mesh +def _collect_nnx_activation_shardings(create_model_fn, config, mesh): + """Run an NNX forward pass in abstract mode to populate _ACTIVATION_SHARDINGS_DUMP. + + get_abstract_state_nnx uses nnx.eval_shape which only traces model initialization, + not __call__. Activation shardings are only collected during a forward pass. 
+ """ + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + + def _nnx_forward(): + model_instance = create_model_fn() + return model_instance( + decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32), + decoder_positions=jnp.ones(input_shape, dtype=jnp.int32), + decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32), + enable_dropout=False, + ) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + jax.eval_shape(_nnx_forward) + + def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state quant = quantizations.configure_quantization(config) - model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + if config.pure_nnx: + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh) + else: + model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) # The learning_rate_schedule is baked into the compiled object. 
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon @@ -103,19 +132,48 @@ def get_shaped_inputs(topology_mesh, config): _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) shaped_rng = jax.ShapeDtypeStruct(example_rng.shape, example_rng.dtype) - # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - model, tx, config, example_rng, topology_mesh - ) + if config.pure_nnx: + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) - # unsharded logical annotations - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, example_rng, topology_mesh) + init_state_fn = create_train_state_fn + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) + + # Shaped state + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True) + + if config.pure_nnx: + # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings. + logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + # For NNX, get_functional_train_with_signature expects the graphdef (static structure), + # not the raw model — mirroring how the training loop does nnx.split(train_state). 
+ with nn_partitioning.axis_rules(config.logical_axis_rules): + abs_train_state = nnx.eval_shape(init_state_fn) + graphdef, _ = nnx.split(abs_train_state) + model = graphdef + else: + # unsharded logical annotations + logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) # Shaped batch shaped_batch = maxtext_utils.get_shaped_batch(config) - shaped_train_args = (abstract_state, shaped_batch, shaped_rng) + if config.pure_nnx: + shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng + else: + shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} + + # Collect activation shardings for NNX by running an abstract forward pass. + # This must happen after get_abstract_state (which uses nnx.eval_shape and only + # traces __init__, not __call__). + if config.debug_sharding and config.pure_nnx: + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -253,7 +311,9 @@ def main(argv: Sequence[str]) -> None: diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( config, abstract_state, state_mesh_shardings, topology_mesh ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. 
+ shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None) @@ -281,12 +341,20 @@ def main(argv: Sequence[str]) -> None: # print weights sharding info under debug sharding mode if config.debug_sharding: max_utils.print_non_trivial_mesh_axis(topology_mesh) - maxtext_utils.print_shardings_params( - shaped_train_args[0].params, - state_mesh_shardings.params, - topology_mesh, - logical_annotations.params, - ) + if config.pure_nnx: + maxtext_utils.print_shardings_params( + shaped_train_args[0], + state_mesh_shardings, + topology_mesh, + logical_annotations, + ) + else: + maxtext_utils.print_shardings_params( + shaped_train_args[0].params, + state_mesh_shardings.params, + topology_mesh, + logical_annotations.params, + ) # Compile print("Jitting and compiling train step...", flush=True) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 7c520cc470..2fd14b87a2 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -22,6 +22,7 @@ The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16. 
""" +import functools import os.path from typing import Sequence @@ -42,8 +43,6 @@ from maxtext.utils import max_utils from maxtext.utils import maxtext_utils -Transformer = models.transformer_as_linen - def _possibly_unroll_params(config, training_state, training_state_annotations, mesh): """Unroll scanned input layers when force_unroll is set.""" @@ -93,12 +92,20 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( - model, None, tx, config, rng, mesh, checkpoint_manager + None, config, mesh, checkpoint_manager, init_state_fn ) num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") @@ -109,7 +116,10 @@ def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" # Model and Optimizer definition quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh, quant, MODEL_MODE_TRAIN) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e4cad14906..ee9ad6dc1b 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from jax.sharding import NamedSharding +from flax import nnx from maxtext.common.common_types import ShardMode from maxtext.utils.sharding import maybe_shard_with_name @@ -49,7 +50,8 @@ def gradient_accumulation_loss_and_grad( config: Model and training configuration object. Must contain `gradient_accumulation_steps` and `shard_optimizer_over_data`. model: The model module. - params: The model parameters (PyTree). + params: The model parameters (PyTree). This is only used for Linen. 
For NNX, + we can get the params from the model. params_shardings: The sharding constraints for the parameters (PyTree). data: A PyTree of batched data. The leading dimension is assumed to be the total batch size (microbatch_size * num_accumulations). @@ -67,12 +69,20 @@ def _maybe_shard_with_name(inputs, sharding_names): """Wrapper of maybe_shard_with_name with fixed shard_mode""" return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding) - # For more efficient DP/ZeRO-1 + GA - if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: + is_nnx = isinstance(model, nnx.Module) + + # For more efficient DP/ZeRO-1 + GA. + # config.ici_data_parallelism may be -1 (auto-fill: resolved at mesh creation time, but + # the config field remains -1). Treat any value != 1 as "data parallelism is active". + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) else: ga_params_shardings = grad_shardings = params_shardings + + if is_nnx: + graphdef, params, rest = nnx.split(model, nnx.Param, ...) 
+ # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints # so that all-gather is done once in the lower precision before the gradient accumulation loop if config.shard_optimizer_over_data: @@ -87,11 +97,27 @@ def convert_to_bf16(param): ga_params = params ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + if is_nnx: + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + else: + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) def accumulate_gradient(acc_grad_and_loss, data): ga_params = acc_grad_and_loss["ga_params"] - (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) + if is_nnx: + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry. + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) 
+ acc_grad_and_loss["rest_state"] = next_rest_state + else: + rng = ( + jax.random.fold_in(dropout_rng, acc_grad_and_loss["total_weights"].astype(jnp.int32)) + if dropout_rng is not None + else None + ) + (_, aux), cur_batch_gradient = grad_func(model, config, data, rng, ga_params, *extra_dpo_args, is_train=True) acc_grad_and_loss["loss"] += aux["total_loss"] acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] acc_grad_and_loss["indexer_loss"] += aux["indexer_loss"] @@ -119,6 +145,8 @@ def reshape_to_microbatch_accumulations(batch_arr): "mtp_loss": 0.0, "ga_params": ga_params, } + if is_nnx: + init_grad_and_loss["rest_state"] = rest grad_and_loss, aux = jax.lax.scan( accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps @@ -134,6 +162,9 @@ def reshape_to_microbatch_accumulations(batch_arr): raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + if is_nnx: + nnx.update(model, grad_and_loss["rest_state"]) + return loss, aux, raw_grads diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 4be05ff7e1..36e612a3f9 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -30,6 +30,7 @@ """ +import functools import os from typing import Any, Sequence @@ -174,12 +175,19 @@ def __init__(self, config: Any, rng: PRNGKeyType): # Model and quantization config self.quant = quantizations.configure_quantization(config) - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state( - model, None, self.config, self.rng, self._mesh, False - ) + if self.config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = 
models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + if self.config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + + self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) def load_and_quantize(self) -> None: """ diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 03095edd73..24099ef22a 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -14,6 +14,7 @@ """ Common LoRA utils needed to support LoRA adapters.""" +from functools import partial import json import jax @@ -166,7 +167,12 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, rng, mesh, True) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 675f920357..1adf9be46d 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -18,25 +18,27 @@ import functools import pickle import os +from typing import Sequence -from flax import linen as nn +from flax import nnx, linen as nn +from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning -from flax.training import train_state +from flax.training.train_state import TrainState import numpy as np -from jax.experimental import mesh_utils -from jax.experimental.serialize_executable import deserialize_and_load - import jax import jax.numpy as jnp +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P +from jax.experimental import mesh_utils +from jax.experimental.serialize_executable import deserialize_and_load import optax - import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager -from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.configs import pyconfig +from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode from maxtext.configs import types from maxtext.inference.page_manager import PageState from maxtext.common import checkpointing @@ -45,6 +47,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils 
from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -86,16 +89,57 @@ def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, sha return sharding.all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, shard_mode) +def to_hashable(x): + """Recursively converts unhashable containers (dict, list) into hashable tuples.""" + if isinstance(x, dict): + return tuple(sorted((k, to_hashable(v)) for k, v in x.items())) + elif isinstance(x, (list, tuple)): + return tuple(to_hashable(v) for v in x) + return x + + def get_functional_train_with_signature( train_step, data_sharding, state_mesh_shardings, model, config, params_shardings=None ): """Get the shardings (both state and data) for `train_step`.""" - functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) + if isinstance(model, nnx.Module): + # For native NNX models, we must avoid unhashable State objects in the JIT closure. + # We perform the merge entirely INSIDE the traced function. + # We use a full split to ensure the graphdef matches the full state (133 leaves). + graphdef, _ = nnx.split(model) + + def nnx_train_step_wrapper(graphdef, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): + # Re-materialize the model from graphdef and the dynamic state. + # We extract the raw values but keep the nnx.State structure and leaf types (Param, RngKey, etc.) + # to satisfy nnx.merge's strict type checking. + def safe_unbox(x): + # If it's a VariableState, extract value. If it's already unboxed (e.g. a tracer), keep it. + return x.value if hasattr(x, "value") else x + + # state.params is already an nnx.State with the correct leaf types. + # We just need to ensure the leaves themselves contain the raw array values (unboxed from .value if needed). + # unboxed from .value if needed). 
+ # nnx.State is a valid PyTree, so we can use jax.tree.map directly. + # But jax.tree.map returns a dict, so we must re-wrap in nnx.State. + unboxed_state = nnx.State(jax.tree.map(safe_unbox, state.params)) + model = nnx.merge(graphdef, unboxed_state) + return train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng) + + # We make graphdef, config, etc. static + functional_train = functools.partial(nnx_train_step_wrapper, graphdef, config, state_mesh_shardings, params_shardings) + static_argnums = () # We partial out the static argnums + donate_argnums = 0 # State is the first argument to the partialed function + else: + functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) + static_argnums = () # We partial out the static argnums of model and config + donate_argnums = 0 # State is the first dynamic argument + functional_train.__name__ = "train_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = (state_mesh_shardings, None) # State, metrics - static_argnums = () # We partial out the static argnums of model and config - donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
return functional_train, in_shardings, out_shardings, static_argnums, donate_argnums @@ -103,7 +147,10 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) functional_eval.__name__ = "eval_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = None # metrics static_argnums = () # We partial out the static argnums of model, config donate_argnums = () # state will be kept instead of being donated in eval_step @@ -196,8 +243,11 @@ def get_train_input_output_trees(func, input_args, input_kwargs): serialized_compiled = load_serialized_compiled(config.compiled_trainstep_file) shaped_batch = get_shaped_batch(config) - example_rng = jax.random.PRNGKey(0) - shaped_input_args = (state, shaped_batch, example_rng) + if config.pure_nnx: + shaped_input_args = (state, shaped_batch) + else: + example_rng = jax.random.PRNGKey(0) + shaped_input_args = (state, shaped_batch, example_rng) shaped_input_kwargs = {} in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs) p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree, execution_devices=execution_devices) @@ -992,15 +1042,15 @@ def _apply_update(path, param): return state.replace(params=new_params) -def init_decode_state(apply_fn, params) -> train_state.TrainState: +def init_decode_state(apply_fn, params) -> TrainState: """Init train state with null opt state for decode.""" - state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # 
type: ignore return state def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state @@ -1021,17 +1071,26 @@ def init_initial_state(model, tx, config, is_training, key): # Reference: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html params_key, dropout_key, aqt_key = jax.random.split(key, 3) - model_vars = model.init( - {"params": params_key, "dropout": dropout_key, "aqt": aqt_key}, - np.ones(input_shape, dtype=jnp.int32), - np.ones(input_shape, dtype=jnp.int32), - encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, - encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None, - # nnx_method="no_op", - ) + if isinstance(model, nnx.Module): + # For NNX models, they are already initialized (or abstractly initialized). + # We extract the FULL state to ensure it matches the graphdef leaf count for nnx.merge. + # Non-differentiable parts will be filtered during the actual grad call. 
+ model_vars = nnx.state(model) + apply_fn = model + else: + model_vars = model.init( + {"params": params_key, "dropout": dropout_key, "aqt": aqt_key}, + np.ones(input_shape, dtype=jnp.int32), + np.ones(input_shape, dtype=jnp.int32), + encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, + encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None, + # nnx_method="no_op", + ) + apply_fn = model.apply + if is_training: - return init_training_state(model.apply, model_vars, tx) - return init_decode_state(model.apply, model_vars) + return init_training_state(apply_fn, model_vars, tx) + return init_decode_state(apply_fn, model_vars) def get_abstract_param(model, config): @@ -1060,14 +1119,13 @@ def get_abstract_param(model, config): return abstract_vars -def setup_decode_state(model, config, rng, mesh, checkpoint_manager): +def setup_decode_state(config, mesh, checkpoint_manager, init_state_fn): """Setup decode state by loading params from a checkpoint. 
Args: - model: the flax model to initialize config: config object - rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: Checkpoint manager + init_state_fn: function to initialize the model state Returns: state: state with decode params loaded from the checkpoint @@ -1077,12 +1135,12 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager): # generate random params max_logging.log("No decode checkpoint specified - generating random weights.") state, state_mesh_annotations, _, _ = setup_initial_state( - model, None, None, config, rng, mesh, checkpoint_manager, False + None, config, mesh, checkpoint_manager, init_state_fn, False ) else: # Load params from checkpoint max_logging.log(f"Loading decode params from {config.load_parameters_path}") - unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(model, None, config, rng, mesh, False) + unboxed_abstract_state, state_mesh_annotations, _ = get_abstract_state(config, mesh, init_state_fn, False) with nn_partitioning.axis_rules(config.logical_axis_rules): params = checkpointing.load_params_from_path( config.load_parameters_path, @@ -1097,49 +1155,44 @@ def setup_decode_state(model, config, rng, mesh, checkpoint_manager): return state, state_mesh_annotations -def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): +def setup_training_state(data_iterator, config, mesh, checkpoint_manager, init_state_fn): is_training = True return setup_initial_state( - model, data_iterator, - tx, config, - rng, mesh, checkpoint_manager, + init_state_fn, is_training, ) def setup_initial_state( - model, data_iterator, - tx, config, - rng, mesh, checkpoint_manager, + init_state_fn, is_training=True, ): """We initialize the model and optimizer state, and optionally load from a checkpoint as necessary. 
Args: - model: the flax model to initialize - tx: the optax.GradientTransformation + data_iterator: data iterator config: config object - rng: jax.prng key mesh: jax.devices() mesh checkpoint_manager: an Orbax checkpointing.CheckpointManager object + init_state_fn: function to initialize the training state is_training: True to initialize training state, False for decode state Returns: - state: the initialized train state + train_state: the initialized train state. For NNX, this is a TrainStateNNX instance state_mesh_annotations: the mesh annotations for the train state """ unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( - model, tx, config, rng, mesh, is_training + config, mesh, init_state_fn, is_training ) # Initialization @@ -1173,40 +1226,75 @@ def setup_initial_state( else: # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] + + # For NNX, convert the pure dict to nnx.State using the abstract state as template + if config.pure_nnx: + nnx.replace_by_pure_dict(unboxed_abstract_state, state) + state = unboxed_abstract_state else: - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) + init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )(rng) + if config.pure_nnx: + state = jax.jit( + lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() if raw_params: # If we loaded a partial state, we need to merge it. 
- state = state.replace(params=raw_params) - - state = max_utils.unbox_logicallypartioned(state) + if config.pure_nnx: + # raw_params should have the same sharding info as in the model + nnx.update(state.model, raw_params) + else: + state = state.replace(params=raw_params) + if not config.pure_nnx: + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator -def get_logical_annotations(model, tx, config, rng, mesh, is_training=True): - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) +def _extract_partition_specs(abstract_state, config): + """Safely extracts PartitionSpecs, routing to the correct Flax API.""" + if not config.enable_nnx: + return nn.get_partition_spec(abstract_state) + + def get_nnx_spec(leaf): + if isinstance(leaf, nnx.VariableState): + return P(*leaf.sharding_names) if getattr(leaf, "sharding_names", None) else P() + elif hasattr(leaf, "sharding_names"): + return P(*leaf.sharding_names) + return P() + + return jax.tree_util.tree_map(get_nnx_spec, abstract_state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + + +def get_logical_annotations(config, mesh, init_state_fn): + init_state_partial = init_state_fn with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) - logical_annotations = nn.get_partition_spec(abstract_state) + logical_annotations = _extract_partition_specs(abstract_state, config) return logical_annotations -def get_abstract_state(model, tx, config, rng, mesh, is_training=True): +def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" - init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training, rng) + if config.pure_nnx: + return get_abstract_state_nnx(config, mesh, init_state_fn, is_training) + + init_state_partial = init_state_fn 
with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) - state_logical_annotations = nn.get_partition_spec(abstract_state) + # Use the new dual-compatible extractor + state_logical_annotations = _extract_partition_specs(abstract_state, config) state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) if is_training and config.shard_optimizer_over_data: @@ -1244,6 +1332,148 @@ def move(path, x): ) +def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. + + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. + """ + + def _make_named_sharding(v): + val = v.get_value() + if not hasattr(val, "shape"): + # Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve + # as-is so the treedef matches abs_var_state in the downstream jax.tree.map. + return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = P() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the + # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. 
+ if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits + # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode + # scan_axis=0 for non-Param types. stacked_rest non-Param variables have + # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. + scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames + # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted + # the scan axis. Only insert if not already present. + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. + context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + rules = composite_rules(context_rules, local_rules) + pspec = P(*from_sharding_rules(out_sharding, rules)) + else: + pspec = P(*out_sharding) + return v.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): + """Calculates the abstract sharded state and memory placement for an NNX TrainState. + + This function performs an abstract trace of the NNX model and optimizer using + `nnx.get_abstract_model`. It resolves logical sharding annotations into physical + JAX shardings and applies memory placement optimizations such as optimizer + sharding and host memory offloading (pinning to CPU RAM). 
+ + Args: + config: Configuration object containing sharding and offloading hyperparameters + (e.g., shard_optimizer_over_data, optimizer_memory_host_offload). + mesh: JAX physical mesh used to resolve logical axis names to physical devices. + nnx_init_trainstate_fn: A zero-argument factory function that produces a + TrainStateNNX instance during the abstract trace. + is_training: Boolean indicating if the state is for training. If True, + optimizer state is processed and memory offloading strategies are applied. + + Returns: + A tuple containing (abstract_sharded_state, None, state_mesh_shardings): + abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with + fully resolved physical sharding and memory_kind metadata. + state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec + objects corresponding to each parameter/variable. + state_mesh_shardings: An nnx.State tree consisting of the raw JAX + Sharding objects corresponding to each parameter/variable. + """ + assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given." + + with nn_partitioning.axis_rules(config.logical_axis_rules): + # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply + # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally + # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, + # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. + # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls + # var.shape for every variable when a global mesh is active, but masked optimizer + # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() + # which has no .shape and would raise AttributeError. 
We handle sharding + # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not + # needed here. + abs_model = nnx.eval_shape(nnx_init_trainstate_fn) + _, abs_var_state = nnx.split(abs_model) + named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + + state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + if is_training and config.shard_optimizer_over_data: + # Add data to sharding for optimizer state + optimizer_sharding = jax.tree_util.tree_map_with_path( + functools.partial(sharding.add_data_to_sharding, mesh), + abstract_state.optimizer, + state_mesh_shardings.optimizer, + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.optimizer_memory_host_offload: + optimizer_sharding = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.parameter_memory_host_offload: + assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading." + _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...) 
+ state_params = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_params, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + nnx.update(state_mesh_shardings, state_params) + + abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings) + state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + return ( + abstract_sharded_state, + state_mesh_annotations, + state_mesh_shardings, + ) + + def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None): """Get a shaped abstraction of the state (including optimizer)""" @@ -1482,26 +1712,41 @@ def print_shardings_params(params, params_sharding, mesh, logical_annotations=No """ Print state shardings comparing Logical Definition vs Physical Result. """ - if not hasattr(params, "params"): - params = {"params": params} - if not hasattr(params_sharding, "params"): - params_sharding = {"params": params_sharding} - if logical_annotations and not hasattr(logical_annotations, "params"): - logical_annotations = {"params": logical_annotations} + if not isinstance(params, nnx.State): + if not hasattr(params, "params"): + params = {"params": params} + if not hasattr(params_sharding, "params"): + params_sharding = {"params": params_sharding} + if logical_annotations and not hasattr(logical_annotations, "params"): + logical_annotations = {"params": logical_annotations} leaves_params, _ = jax.tree_util.tree_flatten_with_path(params) leaves_sharding, _ = jax.tree_util.tree_flatten_with_path(params_sharding) - leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) - for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip(leaves_params, leaves_sharding, leaves_logical): - path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) - shape = jax.typeof(leaf_val) - pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) - 
pspec_str = str(tuple(pspec)) - logical_str = str(leaf_logical_val) + if logical_annotations is not None: + leaves_logical, _ = jax.tree_util.tree_flatten_with_path(logical_annotations) + for (path, leaf_val), (_, leaf_sharding), (_, leaf_logical_val) in zip( + leaves_params, leaves_sharding, leaves_logical + ): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) + logical_str = str(leaf_logical_val) + + message = ( + f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" + ) + max_logging.info(message) + else: + for (path, leaf_val), (_, leaf_sharding) in zip(leaves_params, leaves_sharding): + path_str = "/".join(str(p.key if hasattr(p, "key") else p.name) for p in path) + shape = jax.typeof(leaf_val) + pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh) + pspec_str = str(tuple(pspec)) - message = f" {path_str}\n" f" Shape: {shape}\n" f" Logical: {logical_str}\n" f" Physical: {pspec_str}" - max_logging.info(message) + message = f" {path_str}\n" f" Shape: {shape}\n" f" Physical: {pspec_str}" + max_logging.info(message) print(flush=True) @@ -1534,3 +1779,27 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs): delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging all_host_upload=False, # Only upload from lead host (Host 0) ) + + +def get_mesh_from_config( + config: pyconfig.HyperParameters, + devices: Sequence[jax.Device] | None = None, +) -> Mesh: + """ + Get mesh from the configuration. 
+ + Args: + config: the configuration + devices: the devices + + Returns: + the device mesh + """ + devices_array = create_device_mesh(config, devices) + + if config.shard_mode == ShardMode.EXPLICIT: + axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) + else: + axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) + + return Mesh(devices_array, config.mesh_axes, axis_types=axis_types) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py new file mode 100644 index 0000000000..7378928ef2 --- /dev/null +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -0,0 +1,172 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Utils for MaxText NNX. 
""" + +from functools import partial +from typing import Callable + +from flax import nnx +import jax +from jax.sharding import Mesh, NamedSharding + +from maxtext.utils import max_logging +from maxtext.configs import pyconfig + + +def create_nnx_rngs( + config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None +) -> nnx.Rngs: + """ + Create NNX Rngs + + Args: + config: the configuration + is_training: if the Rngs are for training + rng_key: the Rng key + + Returns: + The NNX Rngs + """ + if rng_key is None: + rng_key = jax.random.PRNGKey(config.init_weights_seed) + + if is_training: + return nnx.Rngs( + params=jax.random.fold_in(rng_key, 0), dropout=jax.random.fold_in(rng_key, 1), aqt=jax.random.fold_in(rng_key, 2) + ) + return nnx.Rngs(params=rng_key) # disable dropout RNG and aqt for inference + + +def get_named_sharding_nnx(abstract_state: nnx.State) -> nnx.State: + """Get named sharding from NNX abstract state. + + Args: + abstract_state: NNX model abstract state created from nnx.get_abstract_model. + + Returns: + named sharding structure + """ + # Don't use nnx.get_named_sharding() because it constructs new shardings. Instead, we + # get the existing sharding from the abstract_state. + # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) + return jax.tree.map( + lambda x: x.sharding, + abstract_state, + is_leaf=lambda x: isinstance(x, jax.ShapeDtypeStruct), + ) + + +def get_partition_spec_nnx(named_sharding: nnx.State) -> nnx.State: + """Get mesh partition spec from named sharding. + + Args: + named_sharding: NNX model named sharding. + + Returns: + mesh partition spec + """ + # The leaf is of type NamedSharding. + return jax.tree.map( + lambda x: x.spec, + named_sharding, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + + +def set_named_sharding_nnx(abstract_state: nnx.State, named_sharding: nnx.State) -> nnx.State: + """Set named sharding to NNX abstract state. 
+ + Args: + abstract_state: NNX model abstract state created from nnx.get_abstract_model(). + named_sharding: named sharding. It must have the same tree structure with abstract_state. + + Returns: + updated abstract_state + """ + return jax.tree.map(lambda x, y: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=y), abstract_state, named_sharding) + + +def move_memory_to_host(path: tuple[str, ...], x: NamedSharding) -> NamedSharding: + """ + Change the memory_kind of the NamedSharding to "pinned_host". This function can be + called by jax.tree_util.tree_map_with_path on a NNX state structure. + + Args: + path: the tree path tuple + x: the NamedSharding corresponding to the path + + Returns: + the NamedSharding with memory_kind set to "pinned_host" + """ + max_logging.log(f"max_utils.py: Moving {path} to host") + # Create the new sharding with the target memory kind + return x.with_memory_kind(kind="pinned_host") + + +def move_memory_to_device(path: tuple[str, ...], x: NamedSharding) -> NamedSharding: + """ + Change the memory_kind of the NamedSharding to "device". This function can be + called by jax.tree_util.tree_map_with_path on a NNX state structure. + + Args: + path: the tree path tuple + x: the NamedSharding corresponding to the path + + Returns: + the NamedSharding with memory_kind set to "device" + """ + max_logging.log(f"max_utils.py: Moving {path} to device") + # Create the new sharding with the target memory kind + return x.with_memory_kind(kind="device") + + +def create_nnx_sharded_model( + abstract_model: nnx.Module, + init_fn: Callable, + mesh: Mesh | None = None, + named_sharding: nnx.State | None = None, +) -> nnx.Module: + """ + Create the model with the given sharding. 
+ + Args: + abstract_model: the abstract model + init_fn: the model init function + mesh: the device mesh + named_sharding: the given sharding + + Returns: + The initialized sharded model + """ + graphdef, abstract_state = nnx.split(abstract_model) + if named_sharding is None: + # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding) + # we get the sharding directly from it. + named_sharding = get_named_sharding_nnx(abstract_state) + + if mesh is None: + mesh = abstract_model.mesh + + # JIT a function that creates the model state with proper sharding from the start. + # By providing out_shardings, we instruct JAX to produce sharded output directly, + # avoiding a large intermediate allocation on a single device. + @partial(jax.jit, out_shardings=named_sharding) + def create_sharded_state(): + model = init_fn() + return jax.lax.with_sharding_constraint(nnx.state(model), named_sharding) + + # Create the model with sharded parameters. + with jax.set_mesh(mesh): + sharded_state = create_sharded_state() + return nnx.merge(graphdef, sharded_state) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index f492744b24..0cc85a883f 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -1,3 +1,17 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # Copyright 2023–2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,22 +31,20 @@ import dataclasses from collections.abc import Sequence +from typing import Callable, overload from functools import partial -from typing import overload - from etils import epath from flax import nnx import flax.linen as nn import jax import jax.numpy as jnp -from jax.sharding import AxisType, Mesh +from jax.sharding import Mesh from maxtext.configs import pyconfig -from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode +from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations from maxtext.models import models from maxtext.utils import max_logging -from maxtext.utils import max_utils -from maxtext.utils import maxtext_utils +from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx from orbax import checkpoint as ocp try: @@ -154,6 +166,7 @@ def from_config( mesh: Mesh | None = None, *, model_mode: str = MODEL_MODE_TRAIN, + rngs: None = None, ) -> nn.Module: ... 
@@ -194,15 +207,7 @@ def from_config( model = from_config(config) """ if mesh is None: - devices_array = maxtext_utils.create_device_mesh(config, devices) - - if config.shard_mode == ShardMode.EXPLICIT: - axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes)) - else: - axis_types = tuple([AxisType.Auto] * len(config.mesh_axes)) - - mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types) - + mesh = maxtext_utils.get_mesh_from_config(config, devices) model = create_model(config, mesh, model_mode=model_mode, rngs=rngs) # Return only the model @@ -226,21 +231,62 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" +def get_nnx_create_model_fn(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None) -> Callable: + + def _create_model(): + is_training = model_mode == MODEL_MODE_TRAIN + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key) + return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - if rng_key is None: - rng_key = jax.random.PRNGKey(config.init_weights_seed) + return _create_model - if model_mode == MODEL_MODE_TRAIN: - rngs = nnx.Rngs(params=rng_key, dropout=1) - else: - rngs = nnx.Rngs(params=rng_key) # disable dropout RNG for inference - return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) +def create_nnx_abstract_model( + config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None +) -> tuple[Callable, nnx.Module]: + """Creates an abstract NNX model. 
- _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) + Returns: + A tuple containing (create_model_fn, abstract_model): + create_model_fn: A zero-argument callable that produces a new model instance. + abstract_model: The stateful NNX model instance in an abstract state. + """ + + with nn.logical_axis_rules(config.logical_axis_rules): + _create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) + if mesh is None: + _tmp = nnx.eval_shape(_create_model) + mesh = _tmp.mesh + # Use nnx.eval_shape + our scan-axis-aware sharding helper instead of + # nnx.get_abstract_model, which uses get_var_pspec internally and ignores + # param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers, + # causing the stacked layers axis to be missing from the PartitionSpec. + with jax.set_mesh(mesh): + abs_model = nnx.eval_shape(_create_model) + graphdef, abs_var_state = nnx.split(abs_model) + named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + return _create_model, nnx.merge(graphdef, abstract_state) + + +def create_nnx_sharded_model_hybrid(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a sharded model for hybrid NNX modules containing Linen sub-modules. + + DEPRECATED: This function is a transitional utility for the Linen-to-NNX + migration. It should be removed once all model components are ported to + pure NNX modules. + + This function specifically handles the complexity of "mixed" state initialization, + where logical sharding annotations must be resolved for both NNX native + Parameters and legacy Linen variables wrapped via the NNX-Linen bridge. 
+ It ensures that both systems correctly respect the provided mesh and + logical axis rules during the abstraction/sharding planning phase. + """ + _create_model_partial = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) with nn.logical_axis_rules(config.logical_axis_rules): abstract_model = nnx.eval_shape(_create_model_partial) @@ -278,7 +324,25 @@ def create_sharded_state(): mesh=model.mesh, logical_annotations=specs, ) + return model + + +def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a NNX model with sharded parameters, possibly loading from a checkpoint.""" + + if config.pure_nnx: + _create_model, abstract_model = create_nnx_abstract_model(config, mesh, devices, model_mode, rng_key) + model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh) + # TODO: print debug_sharding info + else: + model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key) + sharded_state = nnx.state(model) + + if mesh is None: + mesh = model.mesh + + with mesh: if config.load_parameters_path: try: ckptr = ocp.Checkpointer( diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index f50acd269f..0c75f4abb5 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -24,25 +24,23 @@ python3 -m MaxText.muon_utils qwen3-4b True """ - import os import sys from typing import Optional, Tuple import flax.linen as nn +from flax import nnx import jax from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils, model_creation_utils from optax.contrib._muon import MuonDimensionNumbers as mdn -Transformer = models.transformer_as_linen - - def _is_path_contain_any(tuples, path): + """Checks if any element in 'tuples' is present in 
'path'.""" return any(x in path for x in tuples) @@ -107,10 +105,26 @@ def get_transform_tree(tree, path=()): def get_muon_weight_dimension_numbers(model, config, verbose=False): """Extract muon dimension number from model structure.""" - # quickly get param structure without materialization - abstract_param = maxtext_utils.get_abstract_param(model, config) - # get muon dimension number from param - muon_weight_dimension_numbers = get_transform_tree(abstract_param) + + if isinstance(model, nnx.Module): + _, abstract_param, _ = nnx.split(model, nnx.Param, ...) + + def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): + # Convert jax.tree_util.KeyEntry path to Tuple[str, ...] + path_strings = tuple(p.key for p in path if isinstance(p, jax.tree_util.DictKey)) + return transform_logic(path_strings) + + # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. + # This is different from Linen, where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + # The result is an nnx.State with the same structure, where each Param's value holds the mdn result. 
+ muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) + + else: # Linen + # quickly get param structure without materialization + abstract_param = maxtext_utils.get_abstract_param(model, config) + # get muon dimension number from param + muon_weight_dimension_numbers = get_transform_tree(abstract_param) + if verbose: _print_structure_debug(abstract_param, muon_weight_dimension_numbers) return muon_weight_dimension_numbers @@ -118,19 +132,30 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False): def _print_structure_debug(abstract_param, muon_weight_dimension_numbers): """Prints the model structure and the resulting Muon config.""" - # Access the shape from the inner ShapeDtypeStruct and names from the wrapper - # Return a new tree with the same structure containing only shapes/names + + def get_leaf_info(leaf): + # For linen: + # Access the shape from the inner ShapeDtypeStruct and names from the wrapper + # Return a new tree with the same structure containing only shapes/names + if isinstance(leaf, nn.LogicallyPartitioned): + return {"shape": leaf.value.shape, "names": leaf.names} + # For nnx: + # Only return the shape because it doesn't have a wrapper. + elif isinstance(leaf, jax.ShapeDtypeStruct): + return {"shape": leaf.shape} + return {"shape": "N/A"} + info_tree = jax.tree_util.tree_map( - lambda leaf: {"shape": leaf.value.shape, "names": leaf.names}, + get_leaf_info, abstract_param, - is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), + is_leaf=lambda x: isinstance(x, (nn.LogicallyPartitioned, jax.ShapeDtypeStruct)), ) print(f"\n=== Model Structure ===\n{info_tree}") print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}") print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=True): """Initializes a model and retrieves its Muon dimension numbers. 
This function sets up the configuration for a given model, initializes the @@ -154,15 +179,21 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): f"model_name={model_name}", f"scan_layers={scan_layers}", "attention=dot_product", + f"pure_nnx={pure_nnx}", ] config = pyconfig.initialize(argv) # Setup model devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh=mesh, quant=quant) + if pure_nnx: + _, model = model_creation_utils.create_nnx_abstract_model(config, mesh) + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) + if pure_nnx: + muon_weight_dimension_numbers = {"params": nnx.to_pure_dict(muon_weight_dimension_numbers)} return muon_weight_dimension_numbers @@ -172,4 +203,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): sys.exit(1) model_name_arg = sys.argv[1] scan_layers_arg = sys.argv[2].lower() == "true" - get_model_mdn(model_name_arg, scan_layers_arg, verbose=True) + get_model_mdn(model_name_arg, scan_layers_arg, verbose=True, pure_nnx=False) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 74b22548b0..dc508bb199 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -15,7 +15,7 @@ # pylint: disable=line-too-long, disable=bare-except, consider-using-generator """ Utils that are only interesting to MaxText and sharding related. 
""" -from flax import linen as nn +from flax import linen as nn, nnx from collections.abc import Iterable @@ -25,6 +25,7 @@ import optax +from maxtext.configs import pyconfig from maxtext.common.common_types import ShardMode from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -144,7 +145,8 @@ def maybe_shard_with_logical( def remove_size_one_mesh_axis(spec, mesh): """ - Removes mesh axes from a PartitionSpec (P) where the axis size is 1. + Removes mesh axes from a PartitionSpec (P) where the axis size is 1 + OR if the axis is not present in the mesh at all. This is a common optimization to simplify sharding by excluding redundant axes. Function originally from jax._src.core: @@ -157,15 +159,52 @@ def remove_size_one_mesh_axis(spec, mesh): if s is None or s == P.UNCONSTRAINED: new_spec.append(s) # type: ignore elif isinstance(s, tuple): - new_spec.append(tuple(i for i in s if mesh.shape.get(i, 1) != 1)) + # Filter for both existence and size > 1 + new_spec.append(tuple(i for i in s if i in mesh.axis_names and mesh.shape.get(i, 1) != 1)) else: - new_spec.append(None if mesh.shape.get(s, 1) == 1 else s) # type: ignore + # Replace with None if doesn't exist or size == 1 + new_spec.append(s if (s in mesh.axis_names and mesh.shape.get(s, 1) != 1) else None) # type: ignore return P(*new_spec, unreduced=spec.unreduced, reduced=spec.reduced) +def filter_rules_for_mesh(rules, mesh): + """Filters logical axis rules to remove physical axes that don't exist in the mesh.""" + if rules is None: + return None + new_rules = [] + for logical_name, physical_axes in rules: + if isinstance(physical_axes, str): + new_physical = physical_axes if physical_axes in mesh.axis_names else None + elif isinstance(physical_axes, (list, tuple)): + new_physical = tuple(ax for ax in physical_axes if ax in mesh.axis_names) + else: + new_physical = physical_axes + new_rules.append((logical_name, new_physical)) + return tuple(new_rules) + + def 
def maybe_update_params_sharding_with_opt_nnx(
    config: pyconfig.HyperParameters, state_mesh_shardings: nnx.State
) -> tuple[nnx.State, nnx.State]:
  """NNX counterpart of `maybe_update_params_sharding_with_opt`.

  When `config.shard_optimizer_over_data` is enabled (Zero-1 style sharding),
  the shardings of the Adam optimizer's first moment (mu) are copied onto the
  matching model parameters, so parameters and optimizer state share one
  distributed layout across the compute mesh. Otherwise the state is returned
  unchanged.

  Args:
    config: Configuration carrying the `shard_optimizer_over_data` flag.
    state_mesh_shardings: Sharding state for a TrainStateNNX container; model
      parameters live under `.model`, optimizer state under `.optimizer`.

  Returns:
    A tuple `(prev_params_shardings, updated_state_mesh_shardings)`:
      - prev_params_shardings: original Param-only shardings, matching the
        pytree of `nnx.split(model, nnx.Param, ...)[1]`.
      - updated_state_mesh_shardings: state with the updated model shardings
        (unchanged when shard_optimizer_over_data is False).

  Raises:
    NotImplementedError: If Zero-1 sharding is requested but no Adam mu state
      can be located in the optimizer shardings.
  """
  # In TrainStateNNX, parameters live under the 'model' attribute.
  model_shardings = state_mesh_shardings.model

  def _is_variable(node):
    return isinstance(node, nnx.Variable)

  def _params_only(state):
    """Recursively keep only nnx.Param entries of `state` as a plain nested dict.

    Wrapping the result in nnx.State({...}) reproduces the pytree structure of
    nnx.split(model, nnx.Param, ...)[1], so jax.tree.map lines up between the
    Param-only ga_params and these shardings.
    """
    kept = {}
    for key, child in state.items():
      if isinstance(child, nnx.Param):
        kept[key] = child
      elif isinstance(child, nnx.Variable):
        # Skip non-Param variables (RngKey, RngCount, OptVariable, etc.).
        continue
      elif hasattr(child, "items"):
        nested = _params_only(child)
        if nested:
          kept[key] = nested
    return kept

  # Must mirror the structure of ga_params from nnx.split(model, nnx.Param, ...):
  # Param variables only, no rngs.
  prev_params_shardings = nnx.State(_params_only(model_shardings))

  if not config.shard_optimizer_over_data:
    return prev_params_shardings, state_mesh_shardings

  sharded_fp32_params = None
  # Stateless optimizers (e.g. SGD) omit the 'opt_state' key entirely.
  if "opt_state" in state_mesh_shardings.optimizer:
    # The optimizer branch holds the sharding for the nnx.Optimizer.
    opt_state = state_mesh_shardings.optimizer.opt_state

    def _find_adam_mu(node):
      """Depth-first search for the Adam first-moment (mu) shardings."""
      # 1. Direct hit on ScaleByAdamState (Linen path or unflattened NNX).
      if isinstance(node, optax.ScaleByAdamState):
        return node.mu
      # 2. Flattened ScaleByAdamState (nnx.State/dict) exposing mu/nu/count keys.
      if hasattr(node, "__getitem__") and "mu" in node and "nu" in node:
        return node["mu"]
      # 3. Recurse through containers (nnx.State, dict, list, tuple).
      children = None
      if hasattr(node, "values"):  # handles nnx.State and dict
        children = node.values()
      elif isinstance(node, (list, tuple)):
        children = node
      if children:
        for child in children:
          found = _find_adam_mu(child)
          if found is not None:
            return found
      return None

    sharded_fp32_params = _find_adam_mu(opt_state)
  if sharded_fp32_params is None:
    actual_type = type(state_mesh_shardings.optimizer.get("opt_state", "None"))
    raise NotImplementedError(f"Could not find Adam optimizer state in: {actual_type}")

  # Build a path -> new-sharding lookup from mu, then rewrite model_shardings
  # at those paths while preserving rngs and other non-Param variables.
  mu_lookup = {
      path: mu_var.get_value()
      for path, mu_var in jax.tree_util.tree_leaves_with_path(sharded_fp32_params, is_leaf=_is_variable)
  }

  def _adopt_mu_sharding(path, var):
    return var.replace(mu_lookup[path]) if path in mu_lookup else var

  new_model_shardings = jax.tree_util.tree_map_with_path(_adopt_mu_sharding, model_shardings, is_leaf=_is_variable)

  # Clone the state through JAX's unflatten machinery (identity tree_map)
  # instead of the nnx.State constructor or copy.deepcopy:
  # 1. nnx.State({...}) recursively coerces nested plain dicts into nnx.State,
  #    breaking pytree-type agreement with the state from nnx.split (which
  #    keeps nested module states as plain dicts); JAX unflatten preserves the
  #    original container types.
  # 2. copy.deepcopy fails on non-picklable jaxlib.Device objects inside
  #    NamedSharding. Plain attribute assignment stores new_model_shardings as-is.
  updated_state = jax.tree_util.tree_map(lambda x: x, state_mesh_shardings, is_leaf=_is_variable)
  updated_state.model = new_model_shardings

  return prev_params_shardings, updated_state
diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index 1aaf800030..ba6b148b04 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -19,6 +19,7 @@ # See github.com/google/maxtext/issues/20 for more import datetime +from functools import partial import os from typing import Sequence @@ -51,11 +52,21 @@ def checkpoint_loop(config, state=None): Returns: """ - model = from_config(config) + if config.pure_nnx: + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = from_config(config) mesh = model.mesh - init_rng, checkpoint_manager, _, tx = train_utils.create_training_tools(config, model, mesh) - - unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + _, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) + + unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) # A barrier to sync all hosts before starting to restore checkpoint jax.experimental.multihost_utils.sync_global_devices("Barrier before load") checkpoint_load_start = datetime.datetime.now() @@ -81,10 +92,10 @@ def checkpoint_loop(config, state=None): "STANDALONE CHECKPOINTER : Checkpoint restored in :" f" {checkpoint_load_end - checkpoint_load_start}" ) else: # Checkpoint was unavailable, state needs to be initialized - state, _, _, _ = maxtext_utils.setup_training_state(model, None, tx, config, init_rng, mesh, checkpoint_manager) + state, _, _, _ = maxtext_utils.setup_training_state(None, config, mesh, checkpoint_manager, init_state_fn) state = add_entropy_to_checkpoint(state) - start_step = get_first_step(state) # this is the start_step for training + start_step = get_first_step(model, state) # this is the start_step for training for step in np.arange(start_step, config.steps): if checkpoint_manager is not None: start_time = datetime.datetime.now() diff --git a/src/maxtext/utils/standalone_dataloader.py b/src/maxtext/utils/standalone_dataloader.py index e8a942fa1e..ed77e61b35 100644 --- a/src/maxtext/utils/standalone_dataloader.py +++ b/src/maxtext/utils/standalone_dataloader.py @@ -38,13 +38,13 @@ def data_load_loop(config, state=None): """Main data loader loop. Loads batches of data for each training step. 
""" - _, _, _, _, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) + _, _, _, model, mesh, _, data_iterator, _, _, _, state = setup_train_loop(config, recorder=None) data_loader = DataLoader(config, mesh, data_iterator, None) example_batch = None start = datetime.datetime.now() - start_step = get_first_step(state) + start_step = get_first_step(model, state) example_batch = data_loader.load_next_batch() jax.block_until_ready(example_batch) first_end = datetime.datetime.now() diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 54b2755801..2a55c8b6b1 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,10 +15,14 @@ # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting for training in MaxText.""" +import functools import os +from functools import partial + import jax -import functools +from flax import nnx from flax.linen import partitioning as nn_partitioning +from maxtext.layers import train_state_nnx from maxtext.common import checkpointing from maxtext.common.data_loader import create_dataloader from maxtext.common.goodput import GoodputEvent, maybe_record_goodput @@ -33,12 +37,17 @@ from maxtext.trainers.diloco import diloco -def create_training_tools(config, model, mesh): - """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" - init_rng = jax.random.PRNGKey(config.init_weights_seed) +def create_training_optimizer(config, model): + """Creates the optimizer and learning rate schedule.""" learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon tx = optimizers.get_optimizer(config, learning_rate_schedule, model) + return learning_rate_schedule, tx + + +def create_checkpoint_manager(config, mesh, init_state_fn): + """Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager.""" + # pass in model for muon logger = 
checkpointing.setup_checkpoint_logger(config) if config.enable_multi_tier_checkpointing: checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager( @@ -47,7 +56,7 @@ def create_training_tools(config, model, mesh): mesh, ) elif config.enable_emergency_checkpoint: - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager( config.local_checkpoint_directory, config.checkpoint_dir, @@ -85,10 +94,10 @@ def create_training_tools(config, model, mesh): config.enable_autocheckpoint, ) - return init_rng, checkpoint_manager, learning_rate_schedule, tx + return checkpoint_manager -def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings): +def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=None): """Returns a JIT-compiled train step function, which is loaded from a file if specified in the config.""" if config.enable_diloco: functional_train = train_step @@ -110,7 +119,9 @@ def jit_train_step(config, model, state, state_mesh_shardings, data_sharding, tr # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit if config.compiled_trainstep_file != "": max_logging.log("Loading the compiled function...") - execution_devices = model.mesh.devices.flatten().tolist() + # For NNX, model is the GraphDef (no .mesh); use the mesh passed explicitly instead. + execution_mesh = mesh if mesh is not None else model.mesh + execution_devices = execution_mesh.devices.flatten().tolist() # Need to pass train signature and state to determine i/o shapes of train_state for now. 
p_train_step = maxtext_utils.load_compiled(config, functional_train, state, execution_devices) max_logging.log("Loaded compiled function!") @@ -165,7 +176,9 @@ def jit_train_and_eval_step( train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh) data_sharding = sharding.get_input_data_sharding(config, mesh) - p_train_step = jit_train_step(config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings) + p_train_step = jit_train_step( + config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh + ) p_eval_step = None if eval_data_iterator: p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step) @@ -191,15 +204,33 @@ def setup_train_loop(config, recorder, devices=None): data_iterator: data_loader: rampup_manager: the class managing rampup batch sizes - state: the initialized train state + train_state: the initialized train state. For NNX, this is a TrainStateNNX instance """ # pylint: disable=import-outside-toplevel from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): - model = model_creation_utils.from_config(config, devices) - mesh = model.mesh - init_rng, checkpoint_manager, learning_rate_schedule, tx = create_training_tools(config, model, mesh) + is_training = True + init_rng = jax.random.PRNGKey(config.init_weights_seed) + mesh = maxtext_utils.get_mesh_from_config(config, devices) + if config.pure_nnx: + # Create abstract NNX model. 
+ _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) + else: + model = model_creation_utils.from_config(config, devices) + learning_rate_schedule, tx = create_training_optimizer(config, model) + + if config.pure_nnx: + # For NNX, the train state is wrapped in the TrainStateNNX module. + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + init_state_fn = create_train_state_fn + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) + checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): data_iterator, eval_data_iterator = create_data_iterator(config, mesh) @@ -225,8 +256,17 @@ def setup_train_loop(config, recorder, devices=None): ) state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( - model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + data_iterator, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + with nn_partitioning.axis_rules(config.logical_axis_rules): + # train_state is instance of TrainStateNNX + state_graphdef, _ = nnx.get_abstract_model(init_state_fn, mesh) + _, state_params, _ = nnx.split(state.model, nnx.Param, ...) + _, state_mesh_shardings_params, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) 
+ else: + state_params = state.params + state_mesh_shardings_params = state_mesh_shardings.params if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): @@ -244,18 +284,25 @@ def setup_train_loop(config, recorder, devices=None): # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + sharding.assert_params_sufficiently_sharded(state_params, mesh, config.sharding_tolerance) # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(model, tx, config, init_rng, mesh, is_training=True) + if config.pure_nnx: + # TODO: Study how to get logical annotations of NNX module. Because of eager sharding, we + # probably already lost the logical partition info at this moment. 
+ logical_annotations_params = None + else: + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + logical_annotations_params = logical_annotations.params + max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params - ) + maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) + if config.pure_nnx: + raise NotImplementedError("DPO is not supported yet by NNX models.") + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) @@ -279,12 +326,18 @@ def setup_train_loop(config, recorder, devices=None): except FileNotFoundError: step0_restored = None if step0_restored is not None: + # TODO: For pure_nnx, the dpo state manipulation is different. 
reference_params = step0_restored["items"].params["params"] state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) + if config.pure_nnx: + train_state = nnx.merge(state_graphdef, state) + model = train_state.model + else: + train_state = state return ( init_rng, @@ -297,7 +350,7 @@ def setup_train_loop(config, recorder, devices=None): data_loader, rampup_manager, eval_data_iterator, - state, + train_state, ) diff --git a/tests/assets/logits_generation/generate_grpo_golden_logits.py b/tests/assets/logits_generation/generate_grpo_golden_logits.py index e4e9f4fe8a..cae8b9e4d3 100644 --- a/tests/assets/logits_generation/generate_grpo_golden_logits.py +++ b/tests/assets/logits_generation/generate_grpo_golden_logits.py @@ -38,7 +38,7 @@ from maxtext.inference.maxengine import maxengine from maxtext.models import models from maxtext.utils import maxtext_utils -from tests.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs +from tests.post_training.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs import numpy as np import torch import transformers @@ -73,17 +73,27 @@ def setUp(self): devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) # With checkpoint - self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) - self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) + if self.cfg.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng) + self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn) self.state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, self.cfg.logical_axis_rules) self.data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) # Without checkpoint - self.model_no_ckpt_loading = models.transformer_as_linen( - config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN - ) - self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state( - self.model_no_ckpt_loading, self.cfg_no_ckpt_loading, self.rng, mesh, None - ) + if self.cfg_no_ckpt_loading.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model_no_ckpt_loading = models.transformer_as_linen( + config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial( + maxtext_utils.init_initial_state, self.model_no_ckpt_loading, None, self.cfg_no_ckpt_loading, False, self.rng + ) + self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state(self.cfg_no_ckpt_loading, mesh, None, init_state_fn) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", diff --git a/tests/integration/decode_tests.py b/tests/integration/decode_tests.py index f36ecf9efd..163370d087 100644 --- a/tests/integration/decode_tests.py +++ b/tests/integration/decode_tests.py @@ -36,6 +36,8 @@ class DecodeTests(unittest.TestCase): _base_output_directory = get_test_base_output_directory() GEMMA_2B_CKPT_PATH = "gs://maxtext-gemma/2b/2025-11-04-04-33//0/items" + # Decode/inference uses maxengine which does not yet support NNX; use Linen. 
+ _LINEN_FLAGS = ["pure_nnx=False", "enable_nnx=False", "pure_nnx_decoder=False"] CONFIGS = { "base": [ # tests decode None, @@ -49,7 +51,8 @@ class DecodeTests(unittest.TestCase): "max_target_length=128", "per_device_batch_size=1", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", - ], + ] + + _LINEN_FLAGS, "int8": [ # tests decode with int8 quantization None, get_test_config_path(), @@ -64,7 +67,8 @@ class DecodeTests(unittest.TestCase): "quantization=int8", "quantize_kvcache=True", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", - ], + ] + + _LINEN_FLAGS, "pdb_lt_1": [ # tests decode with per_device_batch_size < 1 None, get_test_config_path(), @@ -77,7 +81,8 @@ class DecodeTests(unittest.TestCase): "max_target_length=128", "per_device_batch_size=.25", rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", - ], + ] + + _LINEN_FLAGS, "decode_sampling": [ None, get_test_config_path(), @@ -95,7 +100,8 @@ class DecodeTests(unittest.TestCase): "attention=dot_product", "prompt=I love to", "skip_jax_distributed_system=True", - ], + ] + + _LINEN_FLAGS, } SAMPLING_STRATEGY_CONFIG = { "greedy": [ diff --git a/tests/integration/generate_param_only_checkpoint_test.py b/tests/integration/generate_param_only_checkpoint_test.py index c44831f5d5..94ebebcea1 100644 --- a/tests/integration/generate_param_only_checkpoint_test.py +++ b/tests/integration/generate_param_only_checkpoint_test.py @@ -54,6 +54,9 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta f"attention={attention_type}", "max_target_length=128", "per_device_batch_size=1", + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] + model_config pathways_command = [] @@ -72,6 +75,11 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta dataset_type="tfds", dataset_path=dataset_path, ) + + [ + "pure_nnx=False", + 
"enable_nnx=False", + "pure_nnx_decoder=False", + ] ) state_path = f"{base_output_directory}/runner_{run_date}/checkpoints/0/items" diff --git a/tests/integration/gradient_accumulation_test.py b/tests/integration/gradient_accumulation_test.py index 28523d9dc1..473dda2ead 100644 --- a/tests/integration/gradient_accumulation_test.py +++ b/tests/integration/gradient_accumulation_test.py @@ -28,7 +28,6 @@ from maxtext.common.gcloud_stub import is_decoupled from maxtext.trainers.pre_train.train import main as train_main from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT -from maxtext.trainers.post_train.sft.train_sft_deprecated import main as sft_main from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory @@ -150,6 +149,9 @@ def test_grad_accumulate_same_loss(self): @pytest.mark.integration_test @pytest.mark.tpu_only def test_sft_grad_accumulate_same_loss(self): + pytest.importorskip("tunix") + from maxtext.trainers.post_train.sft.train_sft import main as sft_main # pylint: disable=import-outside-toplevel + sft_main( [ None, @@ -164,6 +166,5 @@ def test_sft_grad_accumulate_same_loss(self): rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", "steps=3", "gradient_accumulation_steps=2", - "use_sft=True", ] ) diff --git a/tests/integration/smoke/inference_microbenchmark_smoke_test.py b/tests/integration/smoke/inference_microbenchmark_smoke_test.py index 4113f51df9..fa97dd3d3a 100644 --- a/tests/integration/smoke/inference_microbenchmark_smoke_test.py +++ b/tests/integration/smoke/inference_microbenchmark_smoke_test.py @@ -53,6 +53,9 @@ def test(self): "weight_dtype=bfloat16", "attention=dot_product", "skip_jax_distributed_system=True", + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] ) run_benchmarks(config) diff --git a/tests/post_training/integration/grpo_correctness.py b/tests/post_training/integration/grpo_correctness.py index 44a3e28df7..adefc03a7e 
100644 --- a/tests/post_training/integration/grpo_correctness.py +++ b/tests/post_training/integration/grpo_correctness.py @@ -13,6 +13,7 @@ # limitations under the License. """GRPO correctness tests""" +import functools import os import unittest @@ -60,8 +61,13 @@ def setUp(self): self.rng = jax.random.PRNGKey(42) devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) - self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) - self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) + if self.cfg.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng) + self.state, _ = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", add_bos_token=False, @@ -121,7 +127,7 @@ def _prepare_maxtext_inputs(self): ) def _prepare_trl_inputs(self): - """Prepare TRL inputs.""" + """Prepare inputs for TRL model.""" tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt") input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1) attention_mask = torch.cat( diff --git a/tests/post_training/integration/grpo_trainer_correctness_test.py b/tests/post_training/integration/grpo_trainer_correctness_test.py index 9a2cfd4078..b880a0e678 100644 --- a/tests/post_training/integration/grpo_trainer_correctness_test.py +++ b/tests/post_training/integration/grpo_trainer_correctness_test.py @@ -25,6 +25,7 @@ pytest tests/post_training/integration/grpo_trainer_correctness_test.py 
""" +import functools import os import subprocess import sys @@ -72,8 +73,13 @@ def setup_maxtext_model(config, mesh): init_rng = jax.random.PRNGKey(config.init_weights_seed) quant = quantizations.configure_quantization(config) - maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, state_mesh_annotations = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng) + state, state_mesh_annotations = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, config.logical_axis_rules) data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) reference_params = jax.tree.map(jnp.copy, state.params["params"]) diff --git a/tests/post_training/integration/sft_trainer_correctness_test.py b/tests/post_training/integration/sft_trainer_correctness_test.py index beeb2036d9..89ac19d0f3 100644 --- a/tests/post_training/integration/sft_trainer_correctness_test.py +++ b/tests/post_training/integration/sft_trainer_correctness_test.py @@ -24,6 +24,7 @@ pytest tests/post_training/integration/sft_trainer_correctness_test.py """ +import functools import os.path import subprocess import sys @@ -117,8 +118,13 @@ def setup_maxtext_model(config): quant = quantizations.configure_quantization(config) devices_array = maxtext_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) 
- state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, init_rng, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config=config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, init_rng) + state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) return maxtext_model, state, init_rng diff --git a/tests/unit/compare_linen_nnx_checkpoint_test.py b/tests/unit/compare_linen_nnx_checkpoint_test.py new file mode 100644 index 0000000000..d3d49e6a63 --- /dev/null +++ b/tests/unit/compare_linen_nnx_checkpoint_test.py @@ -0,0 +1,501 @@ +# Copyright 2023-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for compare_linen_nnx_checkpoint utilities.""" + +import io +import unittest +from unittest.mock import patch +import numpy as np + +from absl import flags as absl_flags +from maxtext.checkpoint_conversion.compare_linen_nnx_checkpoint import ( + is_rng_path, + filter_rngs, + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _normalize_linen_params, + _normalize_nnx_params, + _extract_params, + _normalize_params, + get_tree_structure_info, + print_structure_diff, + compare_params, + transform_nnx_params_for_comparison, +) + + +def _arr(*shape): + """Helper: float32 array of given shape, values 0..prod(shape)-1.""" + return np.arange(int(np.prod(shape)), dtype=np.float32).reshape(shape) + + +def setUpModule(): + # Mark FLAGS as parsed so FLAGS.verbose etc. are accessible without a full + # app.run(). Required flags (ckpt_path_1/2) are not needed in unit tests. + absl_flags.FLAGS.mark_as_parsed() + + +# --------------------------------------------------------------------------- +# is_rng_path +# --------------------------------------------------------------------------- + + +class TestIsRngPath(unittest.TestCase): + """Tests for is_rng_path.""" + + def test_returns_true_for_rngs(self): + self.assertTrue(is_rng_path("model/decoder/rngs/dropout")) + + def test_returns_true_for_rng(self): + self.assertTrue(is_rng_path("model/rngs/params/key")) + + def test_returns_true_case_insensitive(self): + self.assertTrue(is_rng_path("model/RNGs/state")) + self.assertTrue(is_rng_path("model/RNG/state")) + + def test_returns_false_for_normal_path(self): + self.assertFalse(is_rng_path("model/decoder/layers/kernel")) + + def test_returns_false_for_empty_string(self): + self.assertFalse(is_rng_path("")) + + +# --------------------------------------------------------------------------- +# filter_rngs +# --------------------------------------------------------------------------- + + +class TestFilterRngs(unittest.TestCase): + """Tests for filter_rngs.""" + + def 
test_removes_top_level_rngs_key(self): + tree = {"model": {"kernel": _arr(4)}, "rngs": {"dropout": _arr(2)}} + result = filter_rngs(tree) + self.assertNotIn("rngs", result) + self.assertIn("model", result) + + def test_removes_nested_rngs_key(self): + tree = {"model": {"kernel": _arr(4), "rngs": {"key": _arr(2)}}} + result = filter_rngs(tree) + self.assertNotIn("rngs", result["model"]) + self.assertIn("kernel", result["model"]) + + def test_keeps_empty_parent_when_only_child_is_rng(self): + # After filtering, the parent dict becomes empty and is dropped. + tree = {"model": {"rngs": {"key": _arr(2)}}} + result = filter_rngs(tree) + self.assertNotIn("model", result) + + def test_passthrough_for_non_rng_tree(self): + tree = {"params": {"kernel": _arr(4), "bias": _arr(2)}} + result = filter_rngs(tree) + self.assertEqual(set(result.keys()), {"params"}) + + def test_passthrough_for_non_dict_input(self): + arr = _arr(4) + self.assertIs(filter_rngs(arr), arr) + + +# --------------------------------------------------------------------------- +# _has_value_wrappers +# --------------------------------------------------------------------------- + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for _has_value_wrappers.""" + + def test_returns_true_for_direct_value_wrapper(self): + tree = {"value": _arr(3, 4)} + self.assertTrue(_has_value_wrappers(tree)) + + def test_returns_true_for_nested_wrapper(self): + tree = {"decoder": {"kernel": {"value": _arr(2, 2)}}} + self.assertTrue(_has_value_wrappers(tree)) + + def test_returns_false_for_plain_array(self): + self.assertFalse(_has_value_wrappers(_arr(3))) + + def test_returns_false_for_multi_key_dict(self): + tree = {"value": _arr(2), "extra": _arr(2)} + self.assertFalse(_has_value_wrappers(tree)) + + def test_returns_false_for_value_key_with_non_array(self): + tree = {"value": 42} + self.assertFalse(_has_value_wrappers(tree)) + + +# --------------------------------------------------------------------------- +# 
# _strip_value_wrappers
# ---------------------------------------------------------------------------


class TestStripValueWrappers(unittest.TestCase):
    """Tests for _strip_value_wrappers."""

    def test_strips_direct_wrapper(self):
        arr = _arr(3, 4)
        result = _strip_value_wrappers({"value": arr})
        np.testing.assert_array_equal(result, arr)

    def test_strips_nested_wrappers(self):
        arr = _arr(2, 2)
        tree = {"decoder": {"kernel": {"value": arr}}}
        result = _strip_value_wrappers(tree)
        np.testing.assert_array_equal(result["decoder"]["kernel"], arr)

    def test_passthrough_plain_array(self):
        arr = _arr(4)
        self.assertIs(_strip_value_wrappers(arr), arr)

    def test_handles_list(self):
        arr = _arr(2)
        result = _strip_value_wrappers([{"value": arr}])
        np.testing.assert_array_equal(result[0], arr)

    def test_handles_tuple(self):
        arr = _arr(2)
        result = _strip_value_wrappers(({"value": arr},))
        np.testing.assert_array_equal(result[0], arr)

    def test_passthrough_non_array_scalar(self):
        self.assertEqual(_strip_value_wrappers(42), 42)


# ---------------------------------------------------------------------------
# _normalize_linen_params
# ---------------------------------------------------------------------------


class TestNormalizeLinenParams(unittest.TestCase):
    """Tests for _normalize_linen_params."""

    def test_removes_double_nesting(self):
        inner = {"decoder": {"layers": {}}}
        params = {"params": inner}
        result = _normalize_linen_params(params)
        self.assertIs(result, inner)

    def test_removes_double_nesting_encoder(self):
        inner = {"encoder": {"layers": {}}}
        params = {"params": inner}
        result = _normalize_linen_params(params)
        self.assertIs(result, inner)

    def test_passthrough_when_no_double_nesting(self):
        params = {"decoder": {"layers": {}}}
        result = _normalize_linen_params(params)
        self.assertIs(result, params)

    def test_passthrough_when_inner_has_no_decoder_encoder(self):
        params = {"params": {"other_key": {}}}
        result = _normalize_linen_params(params)
        self.assertIs(result, params)


# ---------------------------------------------------------------------------
# _normalize_nnx_params
# ---------------------------------------------------------------------------


class TestNormalizeNnxParams(unittest.TestCase):
    """Tests for _normalize_nnx_params."""

    def test_strips_value_wrappers(self):
        arr = _arr(2, 3)
        params = {"decoder": {"kernel": {"value": arr}}}
        result = _normalize_nnx_params(params)
        np.testing.assert_array_equal(result["decoder"]["kernel"], arr)

    def test_passthrough_plain_tree(self):
        arr = _arr(4)
        params = {"decoder": {"kernel": arr}}
        result = _normalize_nnx_params(params)
        np.testing.assert_array_equal(result["decoder"]["kernel"], arr)


# ---------------------------------------------------------------------------
# detect_format
# ---------------------------------------------------------------------------


class TestDetectFormat(unittest.TestCase):
    """Tests for detect_format."""

    def test_detects_nnx_via_model_key(self):
        state = {"model": {"decoder": {}}, "optimizer": {}}
        self.assertEqual(detect_format(state), "nnx")

    def test_detects_linen_via_double_nested_decoder(self):
        state = {"params": {"params": {"decoder": {}}}}
        self.assertEqual(detect_format(state), "linen")

    def test_detects_linen_via_double_nested_encoder(self):
        state = {"params": {"params": {"encoder": {}}}}
        self.assertEqual(detect_format(state), "linen")

    def test_detects_nnx_via_value_wrappers(self):
        arr = _arr(2, 2)
        state = {"params": {"decoder": {"kernel": {"value": arr}}}}
        self.assertEqual(detect_format(state), "nnx")

    def test_raises_when_no_params_or_model_key(self):
        with self.assertRaises(ValueError):
            detect_format({"step": 0})

    def test_raises_on_undetectable_format(self):
        with self.assertRaises(ValueError):
            detect_format({"params": {"unknown_key": {}}})


# ---------------------------------------------------------------------------
# _extract_params
# ---------------------------------------------------------------------------


class TestExtractParams(unittest.TestCase):
    """Tests for _extract_params."""

    def test_extracts_linen_params(self):
        params = {"params": {"decoder": {}}}
        state = {"params": params, "opt_state": {}}
        self.assertIs(_extract_params(state, "linen"), params)

    def test_extracts_nnx_params_from_model_key(self):
        model = {"decoder": {}}
        state = {"model": model, "optimizer": {}}
        self.assertIs(_extract_params(state, "nnx"), model)

    def test_extracts_nnx_params_falls_back_to_params_key(self):
        params = {"decoder": {}}
        state = {"params": params}
        self.assertIs(_extract_params(state, "nnx"), params)

    def test_returns_empty_dict_when_key_missing(self):
        state = {"optimizer": {}}
        result = _extract_params(state, "linen")
        self.assertEqual(result, {})


# ---------------------------------------------------------------------------
# _normalize_params
# ---------------------------------------------------------------------------


class TestNormalizeParams(unittest.TestCase):
    """Tests for _normalize_params."""

    def test_dispatches_to_linen(self):
        inner = {"decoder": {}}
        params = {"params": inner}
        result = _normalize_params(params, "linen")
        self.assertIs(result, inner)

    def test_dispatches_to_nnx(self):
        arr = _arr(2, 2)
        params = {"decoder": {"kernel": {"value": arr}}}
        result = _normalize_params(params, "nnx")
        np.testing.assert_array_equal(result["decoder"]["kernel"], arr)


# ---------------------------------------------------------------------------
# get_tree_structure_info
# ---------------------------------------------------------------------------


class TestGetTreeStructureInfo(unittest.TestCase):
    """Tests for get_tree_structure_info."""

    def test_returns_shape_and_dtype(self):
        tree = {"kernel": _arr(3, 4), "bias": _arr(4)}
        info = get_tree_structure_info(tree)
        self.assertEqual(info["['kernel']"], ((3, 4), "float32"))
        self.assertEqual(info["['bias']"], ((4,), "float32"))

    def test_handles_nested_tree(self):
        tree = {"decoder": {"kernel": _arr(2, 2)}}
        info = get_tree_structure_info(tree)
        self.assertEqual(len(info), 1)
        shapes = [v[0] for v in info.values()]
        self.assertIn((2, 2), shapes)

    def test_handles_non_array_leaves(self):
        tree = {"step": 5}
        info = get_tree_structure_info(tree)
        self.assertEqual(len(info), 1)
        # next(iter(...)) instead of list(...)[0]: no throwaway list.
        shape, _ = next(iter(info.values()))
        self.assertEqual(shape, "N/A")


# ---------------------------------------------------------------------------
# print_structure_diff
# ---------------------------------------------------------------------------


class TestPrintStructureDiff(unittest.TestCase):
    """Tests for print_structure_diff."""

    def _make_params(self, keys_and_shapes):
        """Build a flat params dict from {key: shape} pairs."""
        return {k: _arr(*s) for k, s in keys_and_shapes.items()}

    def _run_diff(self, p1, p2):
        """Run print_structure_diff with stdout captured; return its 4-tuple."""
        with patch("sys.stdout", new_callable=io.StringIO):
            return print_structure_diff(p1, p2)

    def test_returns_empty_tuples_when_identical(self):
        params = self._make_params({"kernel": (4, 4), "bias": (4,)})
        only1, only2, shape_mm, dtype_mm = self._run_diff(params, params)
        self.assertEqual(only1, [])
        self.assertEqual(only2, [])
        self.assertEqual(shape_mm, [])
        self.assertEqual(dtype_mm, [])

    def test_detects_key_only_in_first(self):
        p1 = self._make_params({"kernel": (4, 4), "bias": (4,)})
        p2 = self._make_params({"kernel": (4, 4)})
        only1, only2, _, _ = self._run_diff(p1, p2)
        self.assertEqual(len(only1), 1)
        self.assertEqual(only2, [])

    def test_detects_key_only_in_second(self):
        p1 = self._make_params({"kernel": (4, 4)})
        p2 = self._make_params({"kernel": (4, 4), "bias": (4,)})
        only1, only2, _, _ = self._run_diff(p1, p2)
        self.assertEqual(only1, [])
        self.assertEqual(len(only2), 1)

    def test_detects_shape_mismatch(self):
        p1 = {"kernel": _arr(4, 4)}
        p2 = {"kernel": _arr(4, 8)}
        _, _, shape_mm, _ = self._run_diff(p1, p2)
        self.assertEqual(len(shape_mm), 1)

    def test_detects_dtype_mismatch(self):
        p1 = {"kernel": np.zeros((4,), dtype=np.float32)}
        p2 = {"kernel": np.zeros((4,), dtype=np.float16)}
        _, _, _, dtype_mm = self._run_diff(p1, p2)
        self.assertEqual(len(dtype_mm), 1)


# ---------------------------------------------------------------------------
# compare_params
# ---------------------------------------------------------------------------


class TestCompareParams(unittest.TestCase):
    """Tests for compare_params."""

    def _compare(self, p1, p2, **kwargs):
        """Run compare_params with its printing suppressed."""
        with patch("builtins.print"):
            return compare_params(p1, p2, **kwargs)

    def test_returns_true_for_identical_params(self):
        params = {"kernel": _arr(4, 4), "bias": _arr(4)}
        self.assertTrue(self._compare(params, params))

    def test_returns_false_for_different_structures(self):
        p1 = {"kernel": _arr(4, 4)}
        p2 = {"kernel": _arr(4, 4), "bias": _arr(4)}
        self.assertFalse(self._compare(p1, p2))

    def test_returns_false_for_shape_mismatch(self):
        p1 = {"kernel": _arr(4, 4)}
        p2 = {"kernel": _arr(4, 8)}
        self.assertFalse(self._compare(p1, p2))

    def test_returns_false_for_dtype_mismatch(self):
        p1 = {"kernel": np.zeros((4,), dtype=np.float32)}
        p2 = {"kernel": np.zeros((4,), dtype=np.float16)}
        self.assertFalse(self._compare(p1, p2))

    def test_value_comparison_passes_when_equal(self):
        arr = _arr(4)
        self.assertTrue(self._compare({"w": arr}, {"w": arr.copy()}, compare_values=True))

    def test_value_comparison_fails_when_different(self):
        p1 = {"w": np.array([1.0, 2.0], dtype=np.float32)}
        p2 = {"w": np.array([1.0, 9.0], dtype=np.float32)}
        self.assertFalse(self._compare(p1, p2, compare_values=True, atol=1e-5, rtol=1e-5))

    def test_value_comparison_passes_within_tolerance(self):
        p1 = {"w": np.array([1.0], dtype=np.float32)}
        p2 = {"w": np.array([1.0 + 1e-7], dtype=np.float32)}
        self.assertTrue(self._compare(p1, p2, compare_values=True, atol=1e-5, rtol=1e-5))

    def test_verbose_mode_does_not_raise(self):
        params = {"kernel": _arr(2, 2)}
        self.assertTrue(self._compare(params, params, verbose=True, compare_values=True))

    def test_nested_params(self):
        params = {"decoder": {"kernel": _arr(4, 4), "bias": _arr(4)}}
        self.assertTrue(self._compare(params, params))


# ---------------------------------------------------------------------------
# transform_nnx_params_for_comparison
# ---------------------------------------------------------------------------


class TestTransformNnxParamsForComparison(unittest.TestCase):
    """Tests for transform_nnx_params_for_comparison."""

    def test_transposes_layer_array(self):
        # Shape (num_layers=3, d=4) -> (d=4, num_layers=3)
        arr = _arr(3, 4)
        tree = {"layers": {"kernel": arr}}
        with patch("builtins.print"):
            result = transform_nnx_params_for_comparison(tree)
        self.assertEqual(result["layers"]["kernel"].shape, (4, 3))

    def test_does_not_transpose_non_layer_array(self):
        arr = _arr(3, 4)
        tree = {"embedding": arr}
        with patch("builtins.print"):
            result = transform_nnx_params_for_comparison(tree)
        self.assertEqual(result["embedding"].shape, (3, 4))

    def test_does_not_transpose_1d_layer_array(self):
        arr = _arr(4)
        tree = {"layers": {"bias": arr}}
        with patch("builtins.print"):
            result = transform_nnx_params_for_comparison(tree)
self.assertEqual(result["layers"]["bias"].shape, (4,)) + + def test_transposes_higher_rank_layer_array(self): + # Shape (num_layers=2, d1=3, d2=5) -> (d1=3, num_layers=2, d2=5) + arr = _arr(2, 3, 5) + tree = {"layers": {"kernel": arr}} + with patch("builtins.print"): + result = transform_nnx_params_for_comparison(tree) + self.assertEqual(result["layers"]["kernel"].shape, (3, 2, 5)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/diloco_test.py b/tests/unit/diloco_test.py index 177fbac98a..58e3c5cafb 100644 --- a/tests/unit/diloco_test.py +++ b/tests/unit/diloco_test.py @@ -77,6 +77,10 @@ def test_diloco_training_simulation_with_mesh(self): f"diloco_sync_period={num_steps-1}", ] ) + if test_config.pure_nnx: + self.skipTest( + "test_diloco_training_simulation_with_mesh uses a hand-crafted Linen TrainState; NNX path not yet covered." + ) with mesh: tx = optax.sgd(learning_rate=0.1) diff --git a/tests/unit/linen_nnx_converter_test.py b/tests/unit/linen_nnx_converter_test.py new file mode 100644 index 0000000000..808990f8cf --- /dev/null +++ b/tests/unit/linen_nnx_converter_test.py @@ -0,0 +1,869 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for linen_nnx_converter utilities.""" + +import unittest +import numpy as np +from unittest.mock import MagicMock, patch + +from maxtext.checkpoint_conversion.linen_nnx_converter import ( + detect_format, + _has_value_wrappers, + _strip_value_wrappers, + _add_value_wrappers, + _transpose_layers_axes, + _stack_layers, + convert_linen_to_nnx, + convert_nnx_to_linen, + _convert_opt_state_linen_to_nnx, + _convert_opt_state_nnx_to_linen, + load_checkpoint, + save_checkpoint, + main, +) + + +def _make_array(*shape): + """Helper to create a numpy array with given shape.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + +class TestDetectFormat(unittest.TestCase): + """Tests for the detect_format function.""" + + def test_raises_when_no_params_key(self): + with self.assertRaises(ValueError): + detect_format({"step": 0}) + + def test_detects_nnx_format_via_model_key(self): + # NNX: top-level "model" key + state = {"model": {"decoder": {"layers": {}}}, "optimizer": {}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_format_double_nested(self): + state = {"params": {"params": {"decoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_format_single_nested_with_value_wrappers(self): + # Old NNX format: params/decoder with {value:} wrappers + arr = _make_array(2, 2) + state = {"params": {"decoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_encoder(self): + state = {"params": {"params": {"encoder": {"layers": {}}}}} + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_encoder_with_value_wrappers(self): + arr = _make_array(2, 2) + state = {"params": {"encoder": {"kernel": {"value": arr}}}} + self.assertEqual(detect_format(state), "nnx") + + def test_detects_nnx_via_optimizer_key(self): + arr = _make_array(2, 2) + state = {"params": {"something": arr}, "optimizer": {"step": 0}} + 
self.assertEqual(detect_format(state), "nnx") + + def test_detects_linen_via_opt_state(self): + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "opt_state": {"params": {"mu": {"decoder": {"kernel": arr}}}}, + } + self.assertEqual(detect_format(state), "linen") + + def test_detects_nnx_via_optimizer_over_opt_state(self): + # "optimizer" key takes precedence for NNX detection + arr = _make_array(2, 2) + state = { + "params": {"something": arr}, + "optimizer": {"step": 0, "opt_state": {}}, + } + self.assertEqual(detect_format(state), "nnx") + + def test_raises_on_undetectable_format(self): + state = {"params": {"some_unknown_key": 42}} + with self.assertRaises(ValueError): + detect_format(state) + + +class TestHasValueWrappers(unittest.TestCase): + """Tests for the _has_value_wrappers helper.""" + + def test_returns_true_for_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"value": arr})) + + def test_returns_true_for_nested_value_wrapper(self): + arr = _make_array(2, 2) + self.assertTrue(_has_value_wrappers({"mu": {"value": arr}})) + + def test_returns_false_for_plain_array(self): + # A plain array is not a {"value": ...} wrapper dict + self.assertFalse(_has_value_wrappers(_make_array(2, 2))) + + def test_returns_false_for_multi_key_dict(self): + arr = _make_array(2, 2) + self.assertFalse(_has_value_wrappers({"value": arr, "extra": arr})) + + def test_returns_false_for_non_array_value(self): + self.assertFalse(_has_value_wrappers({"value": "string"})) + + +class TestStripValueWrappers(unittest.TestCase): + """Tests for the _strip_value_wrappers helper.""" + + def test_strips_single_wrapper(self): + arr = _make_array(3, 4) + result = _strip_value_wrappers({"value": arr}) + np.testing.assert_array_equal(result, arr) + + def test_strips_nested_wrappers(self): + arr = _make_array(2, 2) + wrapped = {"decoder": {"layers": {"kernel": {"value": arr}}}} + stripped = _strip_value_wrappers(wrapped) + 
np.testing.assert_array_equal(stripped["decoder"]["layers"]["kernel"], arr) + + def test_passes_through_plain_array(self): + arr = _make_array(2, 3) + result = _strip_value_wrappers(arr) + np.testing.assert_array_equal(result, arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _strip_value_wrappers([{"value": arr}]) + result_tuple = _strip_value_wrappers(({"value": arr},)) + np.testing.assert_array_equal(result_list[0], arr) + np.testing.assert_array_equal(result_tuple[0], arr) + + def test_passes_through_non_array_value(self): + # A dict with key "value" but scalar content should not be unwrapped + d = {"value": 42} + result = _strip_value_wrappers(d) + self.assertEqual(result, d) + + +class TestAddValueWrappers(unittest.TestCase): + """Tests for the _add_value_wrappers helper.""" + + def test_wraps_array(self): + arr = _make_array(3, 4) + result = _add_value_wrappers(arr) + self.assertIsInstance(result, dict) + self.assertIn("value", result) + np.testing.assert_array_equal(result["value"], arr) + + def test_wraps_nested_arrays(self): + arr = _make_array(2, 2) + nested = {"decoder": {"layers": {"kernel": arr}}} + wrapped = _add_value_wrappers(nested) + self.assertEqual(set(wrapped["decoder"]["layers"]["kernel"].keys()), {"value"}) + np.testing.assert_array_equal(wrapped["decoder"]["layers"]["kernel"]["value"], arr) + + def test_idempotent_on_already_wrapped(self): + arr = _make_array(2) + already_wrapped = {"value": arr} + result = _add_value_wrappers(already_wrapped) + # Should not double-wrap + self.assertEqual(set(result.keys()), {"value"}) + np.testing.assert_array_equal(result["value"], arr) + + def test_handles_list_and_tuple(self): + arr = _make_array(2) + result_list = _add_value_wrappers([arr]) + result_tuple = _add_value_wrappers((arr,)) + self.assertEqual(set(result_list[0].keys()), {"value"}) + self.assertEqual(set(result_tuple[0].keys()), {"value"}) + + def test_passes_through_non_array_scalars(self): + result = 
_add_value_wrappers(42) + self.assertEqual(result, 42) + result_str = _add_value_wrappers("text") + self.assertEqual(result_str, "text") + + +class TestTransposeLayersAxes(unittest.TestCase): + """Tests for the _transpose_layers_axes helper.""" + + def test_noop_when_same_axis(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=0) + np.testing.assert_array_equal(result, arr) + + def test_transposes_axis_0_to_1(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + self.assertEqual(result.shape, (2, 4, 3)) + + def test_transposes_axis_1_to_0(self): + arr = _make_array(2, 4, 3) + result = _transpose_layers_axes(arr, src_axis=1, dst_axis=0) + self.assertEqual(result.shape, (4, 2, 3)) + + def test_transposes_nested_dict(self): + arr = _make_array(4, 2, 3) + tree = {"decoder": {"layers": {"kernel": arr}}} + result = _transpose_layers_axes(tree, src_axis=0, dst_axis=1) + self.assertEqual(result["decoder"]["layers"]["kernel"].shape, (2, 4, 3)) + + def test_passes_through_1d_array(self): + arr = _make_array(5) + result = _transpose_layers_axes(arr, src_axis=0, dst_axis=1) + # 1D array has no axis 1, should be returned unchanged + np.testing.assert_array_equal(result, arr) + + def test_handles_list(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes([arr], src_axis=0, dst_axis=1) + self.assertIsInstance(result, list) + self.assertEqual(result[0].shape, (2, 4, 3)) + + def test_handles_tuple(self): + arr = _make_array(4, 2, 3) + result = _transpose_layers_axes((arr,), src_axis=0, dst_axis=1) + self.assertIsInstance(result, tuple) + self.assertEqual(result[0].shape, (2, 4, 3)) + + +class TestStackLayers(unittest.TestCase): + """Tests for the _stack_layers helper.""" + + def test_stacks_individual_layers(self): + arr0 = _make_array(3, 4) + arr1 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "layers_1": {"mlp": {"kernel": arr1}}, + } + result, 
was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + stacked = result["layers"]["mlp"]["kernel"] + self.assertEqual(stacked.shape, (2, 3, 4)) + np.testing.assert_array_equal(stacked[0], arr0) + np.testing.assert_array_equal(stacked[1], arr1) + + def test_noop_when_no_layer_pattern(self): + arr = _make_array(3, 4) + decoder = {"layers": {"mlp": {"kernel": arr}}} + result, was_stacked = _stack_layers(decoder) + self.assertFalse(was_stacked) + self.assertIs(result, decoder) + + def test_preserves_non_layer_keys(self): + norm_weight = _make_array(4) + arr0 = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr0}}, + "final_norm": {"scale": norm_weight}, + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("final_norm", result) + np.testing.assert_array_equal(result["final_norm"]["scale"], norm_weight) + + def test_stacks_three_layers(self): + arrays = [_make_array(2, 2) for _ in range(3)] + decoder = {f"layers_{i}": {"w": arrays[i]} for i in range(3)} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + stacked = result["layers"]["w"] + self.assertEqual(stacked.shape, (3, 2, 2)) + + def test_non_array_non_dict_leaf(self): + # Scalar leaf — stack_arrays returns first element + decoder = {"layers_0": {"count": 1}, "layers_1": {"count": 2}} + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("layers", result) + + def test_with_missing_key_in_some_layers(self): + arr = _make_array(3, 4) + decoder = { + "layers_0": {"mlp": {"kernel": arr, "bias": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, # no "bias" + } + result, was_stacked = _stack_layers(decoder) + self.assertTrue(was_stacked) + self.assertIn("kernel", result["layers"]["mlp"]) + + +class TestConvertLinenToNNX(unittest.TestCase): + """Tests for the convert_linen_to_nnx function.""" + + def _make_linen_state(self, add_opt_state=False): + 
"""Creates a minimal Linen checkpoint structure.""" + arr = _make_array(2, 4, 3) + state = { + "step": 10, + "params": { + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": arr}}}, + "decoder_norm": {"scale": _make_array(4)}, + } + } + }, + } + if add_opt_state: + state["opt_state"] = {"params": {"mu": {"decoder": {"layers": {"kernel": arr}}}}} + return state + + def test_converts_step_under_optimizer(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 10) + + def test_step_not_at_top_level(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertNotIn("step", result) + + def test_params_stored_under_model_key(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + self.assertIn("model", result) + self.assertNotIn("params", result) + + def test_removes_double_nesting(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # model should have 'decoder' directly, not 'params.decoder' + self.assertIn("decoder", result["model"]) + self.assertNotIn("params", result["model"]) + + def test_adds_value_wrappers(self): + state = self._make_linen_state() + result = convert_linen_to_nnx(state) + # Arrays should be wrapped in {"value": array} + kernel = result["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"] + self.assertIsInstance(kernel, dict) + self.assertIn("value", kernel) + + def test_converts_opt_state_under_optimizer(self): + state = self._make_linen_state(add_opt_state=True) + result = convert_linen_to_nnx(state) + self.assertIn("opt_state", result["optimizer"]) + # Linen opt_state had nested 'params' level; it should be removed + self.assertNotIn("params", result["optimizer"]["opt_state"]) + + def test_no_step_produces_no_optimizer_step(self): + arr = _make_array(2, 4, 3) + state = {"params": {"params": {"decoder": {"layers": {"kernel": arr}}}}} + result = convert_linen_to_nnx(state) 
+ self.assertNotIn("step", result) + self.assertIn("model", result) + + def test_no_double_nesting_still_converts(self): + # Linen state without double-nesting (unusual but handled) + arr = _make_array(2, 4) + state = {"params": {"decoder": {"layers": {"kernel": arr}}}} + result = convert_linen_to_nnx(state) + self.assertIn("decoder", result["model"]) + + def test_no_params_key_only_step(self): + state = {"step": 3} + result = convert_linen_to_nnx(state) + self.assertEqual(result["optimizer"]["step"], 3) + self.assertNotIn("model", result) + + def test_with_per_layer_params_stacked_and_transposed(self): + # Linen checkpoint with layers_0, layers_1 → stacked + transposed to axis 1 + arr = _make_array(3, 4) + state = { + "params": { + "params": { + "decoder": { + "layers_0": {"mlp": {"kernel": arr}}, + "layers_1": {"mlp": {"kernel": arr}}, + } + } + } + } + result = convert_linen_to_nnx(state) + stacked = result["model"]["decoder"]["layers"]["mlp"]["kernel"]["value"] + # Original (3, 4) stacked → (2, 3, 4), transposed to (3, 2, 4) + self.assertEqual(stacked.shape, (3, 2, 4)) + + +class TestConvertNNXToLinen(unittest.TestCase): + """Tests for the convert_nnx_to_linen function.""" + + def _make_nnx_state(self, add_opt_state=False): + """Creates an NNX checkpoint with 'model' and 'optimizer' keys. + + Uses 'attention' (not 'layers') as the sub-key so _convert_layers_to_linen_format + does not try to unstack the data. 
+ """ + arr = _make_array(2, 4, 3) + state = { + "model": { + "decoder": { + "attention": {"wi": {"kernel": {"value": arr}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + "optimizer": {"step": 5}, + } + if add_opt_state: + state["optimizer"]["opt_state"] = { + "mu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + "nu": {"decoder": {"layers": {"kernel": {"value": arr}}}}, + } + return state + + def test_converts_step(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + + def test_adds_double_nesting(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + self.assertIn("params", result["params"]) + self.assertIn("decoder", result["params"]["params"]) + + def test_strips_value_wrappers(self): + state = self._make_nnx_state() + result = convert_nnx_to_linen(state) + kernel = result["params"]["params"]["decoder"]["attention"]["wi"]["kernel"] + self.assertIsInstance(kernel, np.ndarray) + + def test_converts_opt_state(self): + state = self._make_nnx_state(add_opt_state=True) + result = convert_nnx_to_linen(state) + self.assertIn("opt_state", result) + # mu/nu should get a 'params' level added + self.assertIn("params", result["opt_state"]["mu"]) + self.assertIn("params", result["opt_state"]["nu"]) + + def test_backward_compat_params_key(self): + # Old NNX format: "params" instead of "model", top-level "step" + arr = _make_array(2, 4, 3) + state = { + "step": 5, + "params": { + "decoder": { + "layers": {"mlp": {"wi": {"kernel": {"value": arr}}}}, + "decoder_norm": {"scale": {"value": _make_array(4)}}, + } + }, + } + result = convert_nnx_to_linen(state) + self.assertEqual(result["step"], 5) + self.assertIn("decoder", result["params"]["params"]) + + def test_no_step(self): + arr = _make_array(2, 4) + state = {"model": {"decoder": {"layers": {"kernel": {"value": arr}}}}} + result = convert_nnx_to_linen(state) + self.assertNotIn("step", result) + 
class TestRoundTrip(unittest.TestCase):
  """Round-trip conversions must preserve scalars and parameter arrays."""

  def test_linen_to_nnx_to_linen(self):
    # Use "attention" (not "layers") so _convert_layers_to_linen_format
    # does not attempt to unstack the dict as a stacked-layers tensor.
    kernel = _make_array(2, 4, 3)
    linen_state = {
        "step": 42,
        "params": {
            "params": {
                "decoder": {
                    "attention": {"mlp": {"wi": {"kernel": kernel}}},
                    "norm": {"scale": _make_array(4)},
                }
            }
        },
    }
    round_tripped = convert_nnx_to_linen(convert_linen_to_nnx(linen_state))

    self.assertEqual(round_tripped["step"], 42)
    recovered = round_tripped["params"]["params"]["decoder"]["attention"]["mlp"]["wi"]["kernel"]
    np.testing.assert_array_equal(recovered, kernel)

  def test_nnx_to_linen_to_nnx(self):
    kernel = _make_array(2, 4, 3)
    nnx_state = {
        "model": {
            "decoder": {
                "layers": {"mlp": {"wi": {"kernel": {"value": kernel}}}},
            }
        },
        "optimizer": {"step": 7},
    }
    round_tripped = convert_linen_to_nnx(convert_nnx_to_linen(nnx_state))

    self.assertEqual(round_tripped["optimizer"]["step"], 7)
    recovered = round_tripped["model"]["decoder"]["layers"]["mlp"]["wi"]["kernel"]
    self.assertIn("value", recovered)
    np.testing.assert_array_equal(recovered["value"], kernel)


class TestConvertOptState(unittest.TestCase):
  """Tests for _convert_opt_state_linen_to_nnx and _convert_opt_state_nnx_to_linen."""

  def test_linen_to_nnx_removes_params_level(self):
    kernel = _make_array(3, 4)
    converted = _convert_opt_state_linen_to_nnx({"mu": {"params": {"decoder": {"kernel": kernel}}}})
    # The 'params' level is dropped and 'decoder' promoted in its place.
    self.assertNotIn("params", converted["mu"])
    self.assertIn("decoder", converted["mu"])
    # NNX opt_state keeps plain arrays (no {value: ...} wrappers).
    np.testing.assert_array_equal(converted["mu"]["decoder"]["kernel"], kernel)

  def test_linen_to_nnx_handles_list_input(self):
    kernel = _make_array(2, 2)
    moments = [{"decoder": {"kernel": kernel}}, {"decoder": {"kernel": kernel}}]
    converted = _convert_opt_state_linen_to_nnx(moments)
    self.assertIsInstance(converted, list)
    np.testing.assert_array_equal(converted[0]["decoder"]["kernel"], kernel)

  def test_linen_to_nnx_handles_tuple_input(self):
    kernel = _make_array(2, 2)
    converted = _convert_opt_state_linen_to_nnx(({"decoder": {"kernel": kernel}},))
    self.assertIsInstance(converted, tuple)
    np.testing.assert_array_equal(converted[0]["decoder"]["kernel"], kernel)

  def test_linen_to_nnx_handles_non_array_non_dict(self):
    # Scalars pass through unchanged.
    self.assertEqual(_convert_opt_state_linen_to_nnx(42), 42)

  def test_linen_to_nnx_params_key_with_non_dict_value(self):
    # A 'params' key whose converted value is not a dict is stored as-is.
    converted = _convert_opt_state_linen_to_nnx({"params": 99})
    self.assertIn("params", converted)
    self.assertEqual(converted["params"], 99)

  def test_nnx_to_linen_adds_params_level_and_strips(self):
    kernel = _make_array(3, 4)
    converted = _convert_opt_state_nnx_to_linen(
        {
            "mu": {"decoder": {"kernel": {"value": kernel}}},
            "nu": {"decoder": {"kernel": {"value": kernel}}},
        }
    )
    # Both moments gain a nested 'params' level.
    self.assertIn("params", converted["mu"])
    self.assertIn("params", converted["nu"])
    # The {value: ...} wrappers are removed.
    np.testing.assert_array_equal(converted["mu"]["params"]["decoder"]["kernel"], kernel)

  def test_nnx_to_linen_handles_list_input(self):
    kernel = _make_array(2, 2)
    converted = _convert_opt_state_nnx_to_linen([{"decoder": {"kernel": {"value": kernel}}}])
    self.assertIsInstance(converted, list)
    np.testing.assert_array_equal(converted[0]["decoder"]["kernel"], kernel)

  def test_nnx_to_linen_handles_tuple_input(self):
    kernel = _make_array(2, 2)
    converted = _convert_opt_state_nnx_to_linen(({"decoder": {"kernel": {"value": kernel}}},))
    self.assertIsInstance(converted, tuple)
    np.testing.assert_array_equal(converted[0]["decoder"]["kernel"], kernel)

  def test_nnx_to_linen_passes_through_scalars(self):
    self.assertEqual(_convert_opt_state_nnx_to_linen("scalar_string"), "scalar_string")

  def test_nnx_to_linen_value_wrapper_with_non_array_inner(self):
    # {"value": scalar} must NOT be unwrapped — only arrays are unwrapped.
    converted = _convert_opt_state_nnx_to_linen({"value": 42})
    self.assertIn("value", converted)
    self.assertEqual(converted["value"], 42)


class TestConvertLinenToNNXEncoder(unittest.TestCase):
  """Exercises the encoder path in convert_linen_to_nnx."""

  def test_converts_encoder_params(self):
    state = {
        "params": {
            "params": {
                "encoder": {
                    "layers": {"mlp": {"wi": {"kernel": _make_array(2, 4, 3)}}},
                }
            }
        }
    }
    converted = convert_linen_to_nnx(state)
    self.assertIn("encoder", converted["model"])
    kernel = converted["model"]["encoder"]["layers"]["mlp"]["wi"]["kernel"]
    self.assertIsInstance(kernel, dict)
    self.assertIn("value", kernel)

  def test_converts_encoder_with_per_layer_stacking(self):
    layer_kernel = _make_array(3, 4)
    state = {
        "params": {
            "params": {
                "encoder": {
                    "layers_0": {"mlp": {"kernel": layer_kernel}},
                    "layers_1": {"mlp": {"kernel": layer_kernel}},
                }
            }
        }
    }
    converted = convert_linen_to_nnx(state)
    stacked = converted["model"]["encoder"]["layers"]["mlp"]["kernel"]["value"]
    # Stacked at axis 0 -> (2, 3, 4), then transposed to (3, 2, 4).
    self.assertEqual(stacked.shape, (3, 2, 4))
class TestAdditionalEdgeCases(unittest.TestCase):
  """Covers remaining edge cases."""

  def test_detect_format_params_has_params_but_no_decoder_encoder(self):
    # params["params"] exists but contains neither decoder nor encoder, and
    # there is no optimizer/opt_state either, so detection must fail.
    state = {"params": {"params": {"some_other_key": {}}}}
    with self.assertRaises(ValueError):
      detect_format(state)

  def test_detect_format_opt_state_returns_linen(self):
    # Any state carrying "opt_state" (without "model"/"optimizer") is Linen.
    leaf = _make_array(2)
    state = {
        "params": {"something": leaf},
        "opt_state": {"mu": {"decoder": {"kernel": leaf}}},
    }
    self.assertEqual(detect_format(state), "linen")

  def test_add_value_wrappers_value_key_with_non_array(self):
    # {"value": "text"} is not a wrapper (inner is not an array); the
    # function recurses normally and leaves the dict untouched.
    self.assertEqual(_add_value_wrappers({"value": "not_an_array"}), {"value": "not_an_array"})

  def test_convert_nnx_to_linen_no_step(self):
    state = {"model": {"decoder": {"layers": {"kernel": {"value": _make_array(2, 4)}}}}}
    converted = convert_nnx_to_linen(state)
    self.assertNotIn("step", converted)
    self.assertIn("params", converted)

  def test_convert_nnx_to_linen_already_has_params_nesting(self):
    state = {"params": {"params": {"decoder": {"layers": {"kernel": {"value": _make_array(2, 4)}}}}}}
    self.assertIn("params", convert_nnx_to_linen(state))

  def test_convert_nnx_to_linen_no_params_key(self):
    converted = convert_nnx_to_linen({"optimizer": {"step": 8}})
    self.assertEqual(converted["step"], 8)
    self.assertNotIn("params", converted)


class TestLoadCheckpoint(unittest.TestCase):
  """Tests for load_checkpoint with orbax/epath mocked out."""

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
  def test_load_checkpoint_calls_checkpointer_and_returns_state(self, mock_epath, mock_ocp):
    leaf = _make_array(2, 2)
    expected_state = {"params": leaf, "step": 0}

    mock_path = MagicMock()
    mock_epath.Path.return_value = mock_path

    mock_metadata = MagicMock()
    mock_metadata.item_metadata.tree = {"params": leaf}

    mock_ckptr = MagicMock()
    mock_ckptr.metadata.return_value = mock_metadata
    mock_ckptr.restore.return_value = expected_state
    mock_ocp.Checkpointer.return_value = mock_ckptr
    mock_ocp.ArrayRestoreArgs.return_value = MagicMock()

    restored = load_checkpoint("/tmp/test_ckpt")

    mock_epath.Path.assert_called_once_with("/tmp/test_ckpt")
    mock_ocp.Checkpointer.assert_called_once()
    mock_ckptr.metadata.assert_called_once_with(mock_path)
    mock_ckptr.restore.assert_called_once()
    self.assertEqual(restored, expected_state)

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
  def test_load_checkpoint_with_empty_tree_metadata(self, mock_epath, mock_ocp):
    mock_epath.Path.return_value = MagicMock()

    mock_metadata = MagicMock()
    mock_metadata.item_metadata.tree = {}

    mock_ckptr = MagicMock()
    mock_ckptr.metadata.return_value = mock_metadata
    mock_ckptr.restore.return_value = {"step": 5}
    mock_ocp.Checkpointer.return_value = mock_ckptr

    restored = load_checkpoint("/tmp/empty_ckpt")

    self.assertEqual(restored["step"], 5)
class TestSaveCheckpoint(unittest.TestCase):
  """Tests for save_checkpoint with orbax/epath mocked out."""

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
  def test_save_checkpoint_creates_dir_and_saves(self, mock_epath, mock_ocp):
    state = {"params": _make_array(2, 2), "step": 1}

    mock_path = MagicMock()
    mock_epath.Path.return_value = mock_path
    mock_ckptr = MagicMock()
    mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr

    save_checkpoint(state, "/tmp/output")

    mock_epath.Path.assert_called_once_with("/tmp/output")
    mock_path.mkdir.assert_called_once_with(exist_ok=True, parents=True)
    mock_ocp.PyTreeCheckpointer.assert_called_once()
    mock_ckptr.save.assert_called_once_with(mock_path, state, force=True)

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.ocp")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.epath")
  def test_save_checkpoint_passes_state_unchanged(self, mock_epath, mock_ocp):
    state = {"step": 99, "params": {"decoder": {}}}

    mock_epath.Path.return_value = MagicMock()
    mock_ckptr = MagicMock()
    mock_ocp.PyTreeCheckpointer.return_value = mock_ckptr

    save_checkpoint(state, "/tmp/out2")

    # The exact same state object must be forwarded to the checkpointer.
    self.assertIs(mock_ckptr.save.call_args[0][1], state)


class TestMain(unittest.TestCase):
  """Tests for the main() CLI entry point."""

  def _run_main(self, argv):
    with patch("sys.argv", ["prog"] + argv):
      main()

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
  def test_main_explicit_linen_to_nnx(self, mock_load, mock_save):
    mock_load.return_value = {
        "step": 1,
        "params": {"params": {"decoder": {"layers": {"kernel": _make_array(2, 4, 3)}}}},
    }
    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx"])
    mock_load.assert_called_once_with("/src")
    mock_save.assert_called_once()
    saved_state = mock_save.call_args[0][0]
    # NNX layout: decoder sits at the top level of 'model'.
    self.assertIn("decoder", saved_state["model"])
    self.assertEqual(mock_save.call_args[0][1], "/dst")

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
  def test_main_explicit_nnx_to_linen(self, mock_load, mock_save):
    mock_load.return_value = {
        "model": {"decoder": {"layers": {"kernel": {"value": _make_array(2, 4, 3)}}}},
        "optimizer": {"step": 2},
    }
    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=nnx_to_linen"])
    mock_load.assert_called_once_with("/src")
    mock_save.assert_called_once()
    saved_state = mock_save.call_args[0][0]
    # Linen layout: double 'params' nesting.
    self.assertIn("params", saved_state["params"])

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
  def test_main_auto_detects_linen_converts_to_nnx(self, mock_load, mock_save):
    mock_load.return_value = {
        "step": 3,
        "params": {"params": {"decoder": {"layers": {"kernel": _make_array(2, 4, 3)}}}},
    }
    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"])
    mock_save.assert_called_once()
    saved_state = mock_save.call_args[0][0]
    # Auto-detected Linen input is converted to the NNX 'model' layout.
    self.assertIn("decoder", saved_state["model"])

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
  def test_main_auto_detects_nnx_converts_to_linen(self, mock_load, mock_save):
    mock_load.return_value = {
        "model": {"decoder": {"layers": {"kernel": {"value": _make_array(2, 4, 3)}}}},
        "optimizer": {"step": 4},
    }
    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=auto"])
    mock_save.assert_called_once()
    saved_state = mock_save.call_args[0][0]
    # Auto-detected NNX input is converted to the Linen layout.
    self.assertIn("params", saved_state["params"])

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
  def test_main_default_direction_is_auto(self, mock_load, mock_save):
    mock_load.return_value = {
        "params": {"params": {"decoder": {"layers": {"kernel": _make_array(2, 4, 3)}}}},
    }
    # Omitting --direction must behave like --direction=auto.
    self._run_main(["--source_path=/src", "--target_path=/dst"])
    mock_save.assert_called_once()

  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.save_checkpoint")
  @patch("maxtext.checkpoint_conversion.linen_nnx_converter.load_checkpoint")
  def test_main_scan_layers_false(self, mock_load, mock_save):
    layer_kernel = _make_array(3, 4)
    mock_load.return_value = {
        "params": {
            "params": {
                "decoder": {
                    "layers_0": {"mlp": {"kernel": layer_kernel}},
                    "layers_1": {"mlp": {"kernel": layer_kernel}},
                }
            }
        }
    }
    self._run_main(["--source_path=/src", "--target_path=/dst", "--direction=linen_to_nnx", "--no-scan_layers"])
    layers = mock_save.call_args[0][0]["model"]["decoder"]["layers"]
    # With scan_layers disabled, layers remain unstacked under integer string keys.
    self.assertIsInstance(layers, dict)
    self.assertTrue(all(key.isdigit() for key in layers))


if __name__ == "__main__":
  unittest.main()
config = self.init_pyconfig() + if config.pure_nnx: + self.skipTest("test_unscan_train_state_params uses Linen state.params; NNX path not yet covered.") _, _, sharding, _, mesh, *_, state = setup_train_loop(config, None) diff --git a/tests/unit/maxengine_test.py b/tests/unit/maxengine_test.py index fa712672d2..d94c7ca53d 100644 --- a/tests/unit/maxengine_test.py +++ b/tests/unit/maxengine_test.py @@ -42,6 +42,8 @@ class MaxEngineTest(unittest.TestCase): def setUp(self): super().setUp() self.cfg = self.init_pyconfig() + if self.cfg.pure_nnx: + self.skipTest("Pure NNX support has not been implemented yet in MaxEngine.") self.rng = jax.random.PRNGKey(0) def init_pyconfig(self, **kwargs): @@ -82,6 +84,8 @@ def test_stack_and_unstack_prefill_cache(self): enable_checkpointing=False, stack_prefill_result_cache=True, ) + if config.pure_nnx: + self.skipTest("Pure NNX support has not been implemented yet in MaxEngine.") engine = maxengine.MaxEngine(config, jax.devices()) num_layers = engine.config.num_decoder_layers input_d = { diff --git a/tests/unit/maxtext_utils_nnx_test.py b/tests/unit/maxtext_utils_nnx_test.py new file mode 100644 index 0000000000..0eb1f7ef77 --- /dev/null +++ b/tests/unit/maxtext_utils_nnx_test.py @@ -0,0 +1,182 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for the common MaxText NNX utilities."""
import unittest
from dataclasses import dataclass
from typing import Any

import jax
from flax import nnx
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from maxtext.utils import maxtext_utils_nnx


class TestMaxTextUtilsNNX(unittest.TestCase):
  """Test the functions for MaxText Utils."""

  @dataclass
  class MockConfig:
    """Minimal mock for pyconfig.HyperParameters."""

    init_weights_seed: int = 42

  class TinyModel(nnx.Module):
    """A tiny NNX model with logical annotations.

    Annotations are required to test that sharding extraction logic works.
    """

    def __init__(self, rngs: nnx.Rngs):
      self.linear = nnx.Linear(
          jax.device_count(),
          jax.device_count(),
          kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("data", None)),
          # nnx.initializers.zeros is the initializer function itself, not a
          # factory like lecun_normal(), so it must not be called here.
          bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("data",)),
          rngs=rngs,
      )

  def tiny_model_init_fn(self):
    """Factory function for model initialization."""
    return self.TinyModel(rngs=nnx.Rngs(0))

  def setUp(self):
    # Create a mesh for sharding tests.
    # NamedSharding requires an active Mesh to resolve logical names.
    self.devices = mesh_utils.create_device_mesh((jax.device_count(),))
    self.mesh = Mesh(self.devices, axis_names=("data",))

  def test_create_nnx_rngs_training(self):
    # Using Any to satisfy static type checkers for the MockConfig.
    config: Any = self.MockConfig(init_weights_seed=123)
    rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=True)

    self.assertIsInstance(rngs, nnx.Rngs)
    # nnx.Rngs does not have a .streams attribute; check for stream
    # attributes directly on the object.
    self.assertTrue(hasattr(rngs, "params"))
    self.assertTrue(hasattr(rngs, "dropout"))
    self.assertTrue(hasattr(rngs, "aqt"))

  def test_create_nnx_rngs_inference(self):
    config: Any = self.MockConfig(init_weights_seed=123)
    rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=False)

    self.assertIsInstance(rngs, nnx.Rngs)
    # 'params' exists but 'dropout' and 'aqt' are excluded at inference time.
    self.assertTrue(hasattr(rngs, "params"))
    self.assertFalse(hasattr(rngs, "dropout"))
    self.assertFalse(hasattr(rngs, "aqt"))

  def test_move_memory(self):
    sharding = NamedSharding(self.mesh, P("data"))
    self.assertNotEqual(sharding.memory_kind, "pinned_host")

    path = ("layers", "linear", "kernel")
    host_sharding = maxtext_utils_nnx.move_memory_to_host(path, sharding)

    self.assertEqual(host_sharding.memory_kind, "pinned_host")
    self.assertEqual(host_sharding.spec, P("data"))

    device_sharding = maxtext_utils_nnx.move_memory_to_device(path, sharding)

    self.assertEqual(device_sharding.memory_kind, "device")
    self.assertEqual(device_sharding.spec, P("data"))

  def test_get_set_named_sharding_nnx(self):
    # 1. Create the abstract state using the standard NNX functional API.
    _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh)

    # 2. Test extraction.
    extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)

    # Verify kernel and bias match the P("data") annotations from TinyModel.
    self.assertEqual(extracted_shardings.linear.kernel.get_value().spec, P("data", None))
    self.assertEqual(extracted_shardings.linear.bias.get_value().spec, P("data"))

    # Target kernel spec update.
    new_kernel_spec = P(None, "data")

    def update_spec_fn(path, leaf_sharding):
      path_str = jax.tree_util.keystr(path)
      if "linear" in path_str and "kernel" in path_str:
        # Construct a new NamedSharding with the requested logical spec.
        return NamedSharding(leaf_sharding.mesh, new_kernel_spec)
      return leaf_sharding

    # Apply the spec change to the extracted sharding tree. Use
    # jax.tree_util.tree_map_with_path for consistency with the other call
    # sites in this module (jax.tree.map_with_path is the same function).
    extracted_shardings = jax.tree_util.tree_map_with_path(update_spec_fn, extracted_shardings)

    # 3. Test setting new shardings: transform the extracted shardings to
    # host memory and write them back into the abstract state.
    new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings)
    updated_abstract = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, new_shardings)

    # Verify the metadata inside the abstract state leaf has updated its sharding.
    self.assertEqual(updated_abstract.linear.kernel.sharding.memory_kind, "pinned_host")
    # Also verify the spec was updated successfully.
    self.assertEqual(updated_abstract.linear.kernel.sharding.spec, new_kernel_spec)

    # 4. Verify named sharding is preserved after NNX merge (update) and split (state).
    model = self.tiny_model_init_fn()
    nnx.update(model, updated_abstract)
    re_extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model))

    # Verify kernel and bias have the expected sharding.
    self.assertEqual(re_extracted_shardings.linear.kernel.get_value().spec, new_kernel_spec)
    self.assertEqual(re_extracted_shardings.linear.bias.get_value().spec, P("data"))

  def test_create_nnx_sharded_model(self):
    # 1. Create abstract model.
    graphdef, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh)
    abstract_model = nnx.merge(graphdef, abstract_state)

    # 2. Modify shardings to trigger host offloading.
    extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)
    new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings)

    # 3. Run the sharded creation: pass the abstract model and use the custom
    # sharding for instantiation.
    sharded_model = maxtext_utils_nnx.create_nnx_sharded_model(
        abstract_model, self.tiny_model_init_fn, mesh=self.mesh, named_sharding=new_shardings
    )

    # 4. Verify the model is concrete (contains Arrays) and sharded on host.
    self.assertIsInstance(sharded_model.linear.kernel[...], jax.Array)
    self.assertEqual(sharded_model.linear.kernel[...].sharding.memory_kind, "pinned_host")

  def test_get_partition_spec_nnx(self):
    """Verifies extraction of PartitionSpecs from NamedShardings."""
    # 1. Create abstract state and get sharding.
    _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh)
    extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)

    # 2. Execute extraction.
    spec = maxtext_utils_nnx.get_partition_spec_nnx(extracted_shardings)

    # 3. Verify that the leaves are now raw PartitionSpecs.
    # Expected values derived from the TinyModel definition.
    expected_spec_k = P("data", None)
    expected_spec_b = P("data")

    self.assertEqual(spec["linear"]["kernel"], expected_spec_k)
    self.assertEqual(spec["linear"]["bias"], expected_spec_b)
    self.assertNotIsInstance(spec["linear"]["kernel"], NamedSharding)


if __name__ == "__main__":
  unittest.main()
"decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}}, } self.state = train_state.TrainState( - step=0, - apply_fn=self.model.apply, - params=self.initial_params, - tx=None, - opt_state={}, + step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={} ) def test_update_mode_add(self): @@ -194,10 +194,10 @@ def test_update_mode_add(self): self.assertTrue(jnp.allclose(actual, expected)) # Other values are untouched - original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] + original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] # pylint: disable=unsubscriptable-object new_layer_0 = new_state.params["layers"]["layer_0"]["bias"] self.assertTrue(jnp.array_equal(original_layer_0, new_layer_0)) - original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] + original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] # pylint: disable=unsubscriptable-object new_layer_1 = new_state.params["layers"]["layer_1"]["bias"] self.assertTrue(jnp.array_equal(original_layer_1, new_layer_1)) @@ -262,7 +262,7 @@ def test_init_training_state(self): @nnx.register_variable_name("special_variables") -class SpecialVariables(nnx.Variable): +class SpecialVariables(nnx.Variable): # pylint: disable=abstract-method pass @@ -279,7 +279,7 @@ def __call__(self, x, y, encoder_images=None, nnx_method=None, model_mode=None): return x -class TrainState(train_state.TrainState): +class TrainState(train_state.TrainState): # pylint: disable=abstract-method other_variables: nnx.State @@ -348,21 +348,27 @@ def setUp(self): # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode extra_args = get_decoupled_parallelism_overrides() self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args) + if self.config.pure_nnx: + self.skipTest("Pure NNX support has not been implemented yet.") devices_array = maxtext_utils.create_device_mesh(self.config) self.mesh = Mesh(devices_array, 
self.config.mesh_axes) quant = quantizations.configure_quantization(self.config) - self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + self.model = models.transformer_as_linen(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) def test_setup_decode_state(self): rng = random.PRNGKey(0) - state, _ = maxtext_utils.setup_decode_state(self.model, self.config, rng, self.mesh, None) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + state, _ = maxtext_utils.setup_decode_state(self.config, self.mesh, None, init_state_fn) self.assertEqual(state.tx, None) self.assertEqual(state.opt_state, {}) def test_setup_initial_state(self): rng = random.PRNGKey(0) tx = optax.adam(learning_rate=0.001) - state, _, _, _ = maxtext_utils.setup_initial_state(self.model, None, tx, self.config, rng, self.mesh, None) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, tx, self.config, True, rng) + state, _, _, _ = maxtext_utils.setup_initial_state( # type: ignore[arg-type] + None, self.config, self.mesh, None, init_state_fn + ) self.assertEqual(state.tx, tx) self.assertNotEqual(state.opt_state, {}) @@ -908,19 +914,440 @@ def test_wsd_schedule(self): self.assertIn("wsd_decay_steps_fraction", str(cm.exception)) -class TestGetAbstractState(unittest.TestCase): - """Test class for get_abstract_state.""" +class TestMeshUtils(unittest.TestCase): + """Test suite for the mesh creation utility function.""" + + @dataclass + class MockConfig: + """Minimal mock for pyconfig.HyperParameters.""" + + init_weights_seed: int = 42 + shard_mode: str = ShardMode.EXPLICIT + mesh_axes: Sequence[str] = field(default_factory=lambda: ["data", "model"]) def setUp(self): - extra_args = get_decoupled_parallelism_overrides() - self.config = pyconfig.initialize( - [None, get_test_config_path()], - **extra_args, - enable_checkpointing=False, - model_name="llama3.1-8b", - 
class TestMeshUtils(unittest.TestCase):
  """Test suite for the mesh creation utility function."""

  @dataclass
  class MockConfig:
    """Minimal mock for pyconfig.HyperParameters."""

    init_weights_seed: int = 42
    shard_mode: str = ShardMode.EXPLICIT
    mesh_axes: Sequence[str] = field(default_factory=lambda: ["data", "model"])

  def setUp(self):
    # Dummy device array for the mocked create_device_mesh to return.
    self.devices_array = np.array(jax.devices())

  # FIX: the patch target must be the module where create_device_mesh is
  # looked up at call time. The package lives under maxtext.utils in this
  # source tree ("MaxText.maxtext_utils" is the pre-rename import path and
  # raises ModuleNotFoundError here).
  @patch("maxtext.utils.maxtext_utils.create_device_mesh")
  def test_get_mesh_explicit_mode(self, mock_create_device_mesh):
    """Tests that ShardMode.EXPLICIT sets axis_types to Explicit."""
    # 1. Setup mock.
    mock_create_device_mesh.return_value = self.devices_array[:1].reshape((1,))
    config = self.MockConfig(shard_mode=ShardMode.EXPLICIT, mesh_axes=["data"])

    # 2. Run function.
    mesh = maxtext_utils.get_mesh_from_config(config)

    # 3. Assertions: the internal utility was called correctly and the mesh
    # carries explicit axis types.
    mock_create_device_mesh.assert_called_once_with(config, None)
    self.assertEqual(mesh.axis_names, ("data",))
    self.assertEqual(mesh.axis_types, (AxisType.Explicit,))

  @patch("maxtext.utils.maxtext_utils.create_device_mesh")
  def test_get_mesh_auto_mode(self, mock_create_device_mesh):
    """Tests that ShardMode.AUTO sets axis_types to Auto."""
    mock_create_device_mesh.return_value = self.devices_array[:2].reshape((2, 1))
    config = self.MockConfig(shard_mode=ShardMode.AUTO, mesh_axes=["data", "model"])

    mesh = maxtext_utils.get_mesh_from_config(config)

    self.assertEqual(len(mesh.axis_types), 2)
    self.assertTrue(all(t == AxisType.Auto for t in mesh.axis_types))

  @patch("maxtext.utils.maxtext_utils.create_device_mesh")
  def test_get_mesh_with_provided_devices(self, mock_create_device_mesh):
    """Tests that provided devices are passed through to the mesh creator."""
    config = self.MockConfig()
    specific_devices = self.devices_array[:2].reshape((1, 2))
    mock_create_device_mesh.return_value = specific_devices

    _ = maxtext_utils.get_mesh_from_config(config, devices=specific_devices)

    # Verify the second argument to create_device_mesh was our device list.
    mock_create_device_mesh.assert_called_once_with(config, specific_devices)


class TestGetFunctionalTrainWithSignature(unittest.TestCase):
  """Tests for get_functional_train_with_signature."""

  def _make_mock_step(self):
    def train_step(_model, _config, _state_shardings, _params_shardings, state, _batch, _rng=None):
      return state, {}

    return train_step

  def _make_mock_config(self, pure_nnx=True):
    cfg = MagicMock()
    cfg.pure_nnx = pure_nnx
    return cfg

  def test_returns_five_tuple(self):
    step = self._make_mock_step()
    result = maxtext_utils.get_functional_train_with_signature(
        step, "data_sharding", "state_shardings", "model", self._make_mock_config()
    )
    self.assertEqual(len(result), 5)

  def test_functional_train_has_correct_name(self):
    step = self._make_mock_step()
    fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(
        step, "data_sharding", "state_shardings", "model", self._make_mock_config()
    )
    self.assertEqual(fn.__name__, "train_step")

  def test_linen_in_shardings_includes_rng(self):
    """pure_nnx=False: in_shardings should be (state, batch, rng)."""
    step = self._make_mock_step()
    _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature(
        step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=False)
    )
    self.assertEqual(len(in_shardings), 3)
    self.assertIsNone(in_shardings[2])  # rng sharding is None

  def test_nnx_in_shardings_excludes_rng(self):
    """pure_nnx=True: in_shardings should be (state, batch) — no rng slot."""
    step = self._make_mock_step()
    _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature(
        step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=True)
    )
    self.assertEqual(len(in_shardings), 2)

  def test_donate_argnums_is_zero(self):
    step = self._make_mock_step()
    _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature(
        step, "data_sharding", "state_shardings", "model", self._make_mock_config()
    )
    self.assertEqual(donate_argnums, 0)

  def test_functional_train_is_partial(self):
    """functional_train should partially apply model and config."""
    received = {}
    cfg = self._make_mock_config()

    def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None):
      received["model"] = model
      received["config"] = config
      return state, {}

    fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", cfg)
    fn("state", "batch")
    self.assertEqual(received["model"], "my_model")
    self.assertEqual(received["config"], cfg)
class TestGetFunctionalEvalWithSignature(unittest.TestCase):
  """Tests for get_functional_eval_with_signature."""

  def _make_mock_eval_step(self):
    def eval_step(_model, _config, _state, _batch, _rng=None):
      return {}

    return eval_step

  def _make_mock_config(self, pure_nnx=True):
    cfg = MagicMock()
    cfg.pure_nnx = pure_nnx
    return cfg

  def test_returns_five_tuple(self):
    step = self._make_mock_eval_step()
    result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config())
    self.assertEqual(len(result), 5)

  def test_functional_eval_has_correct_name(self):
    step = self._make_mock_eval_step()
    fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config())
    self.assertEqual(fn.__name__, "eval_step")

  def test_out_shardings_is_none(self):
    step = self._make_mock_eval_step()
    _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(
        step, "ds", "ss", "model", self._make_mock_config()
    )
    self.assertIsNone(out_shardings)

  def test_donate_argnums_is_empty(self):
    step = self._make_mock_eval_step()
    _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(
        step, "ds", "ss", "model", self._make_mock_config()
    )
    self.assertEqual(donate_argnums, ())

  def test_nnx_in_shardings_excludes_rng(self):
    """pure_nnx=True: in_shardings should be (state, batch) — no rng slot."""
    step = self._make_mock_eval_step()
    _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature(
        step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=True)
    )
    self.assertEqual(len(in_shardings), 2)

  def test_linen_in_shardings_includes_rng(self):
    """pure_nnx=False: in_shardings should be (state, batch, rng)."""
    step = self._make_mock_eval_step()
    _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature(
        step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=False)
    )
    self.assertEqual(len(in_shardings), 3)


class TestGetShapedBatch(unittest.TestCase):
  """Tests for get_shaped_batch."""

  def _make_cfg(self, *, enable_diloco=False, use_multimodal=False, use_audio=False):
    cfg = MagicMock()
    cfg.enable_diloco = enable_diloco
    cfg.global_batch_size_to_load = 4
    cfg.max_target_length = 16
    cfg.use_multimodal = use_multimodal
    cfg.use_audio = use_audio
    if enable_diloco:
      cfg.num_diloco_replicas = 2
    return cfg

  def test_standard_keys_present(self):
    shaped = maxtext_utils.get_shaped_batch(self._make_cfg())
    for key in (
        "inputs",
        "inputs_position",
        "inputs_segmentation",
        "targets",
        "targets_position",
        "targets_segmentation",
    ):
      self.assertIn(key, shaped)

  def test_standard_shape(self):
    cfg = self._make_cfg()
    shaped = maxtext_utils.get_shaped_batch(cfg)
    self.assertEqual(shaped["inputs"].shape, (cfg.global_batch_size_to_load, cfg.max_target_length))

  def test_diloco_shape(self):
    cfg = self._make_cfg(enable_diloco=True)
    shaped = maxtext_utils.get_shaped_batch(cfg)
    # DiLoCo adds a leading replica dimension and splits the batch across it.
    expected_shape = (
        cfg.num_diloco_replicas,
        cfg.global_batch_size_to_load // cfg.num_diloco_replicas,
        cfg.max_target_length,
    )
    self.assertEqual(shaped["inputs"].shape, expected_shape)

  def test_no_image_key_without_multimodal(self):
    self.assertNotIn("images", maxtext_utils.get_shaped_batch(self._make_cfg(use_multimodal=False)))

  def test_no_audio_key_without_audio(self):
    self.assertNotIn("audios", maxtext_utils.get_shaped_batch(self._make_cfg(use_audio=False)))

  def test_all_values_are_shape_dtype_struct(self):
    for leaf in maxtext_utils.get_shaped_batch(self._make_cfg()).values():
      self.assertIsInstance(leaf, jax.ShapeDtypeStruct)


class TestShouldPreventCseInRemat(unittest.TestCase):
  """Tests for should_prevent_cse_in_remat."""

  def _make_cfg(self, scan_layers=False, gradient_accumulation_steps=1, hardware="tpu"):
    cfg = MagicMock()
    cfg.scan_layers = scan_layers
    cfg.gradient_accumulation_steps = gradient_accumulation_steps
    cfg.hardware = hardware
    return cfg

  def test_scan_layers_returns_false(self):
    self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(self._make_cfg(scan_layers=True)))

  def test_gpu_with_grad_accum_returns_false(self):
    cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="gpu")
    self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(cfg))

  def test_gpu_multiprocess_with_grad_accum_returns_false(self):
    cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="gpu_multiprocess")
    self.assertFalse(maxtext_utils.should_prevent_cse_in_remat(cfg))

  def test_tpu_with_grad_accum_returns_true(self):
    cfg = self._make_cfg(scan_layers=False, gradient_accumulation_steps=4, hardware="tpu")
    self.assertTrue(maxtext_utils.should_prevent_cse_in_remat(cfg))

  def test_default_case_returns_true(self):
    self.assertTrue(maxtext_utils.should_prevent_cse_in_remat(self._make_cfg()))


class TestCalculateTokensTrainingPerDevice(unittest.TestCase):
  """Tests for calculate_tokens_training_per_device."""

  def test_basic_calculation(self):
    cfg = MagicMock()
    cfg.max_target_length = 128
    cfg.per_device_batch_size = 2
    cfg.gradient_accumulation_steps = 4
    self.assertEqual(maxtext_utils.calculate_tokens_training_per_device(cfg), 128 * 2 * 4)


class TestCalculateIndexerMaskRatio(unittest.TestCase):
  """Tests for calculate_indexer_mask_ratio."""

  def test_half_topk(self):
    # K=T/2: ratio=0.5, mask = 0.5 - 0.5*0.25 = 0.375
    ratio = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=4, max_target_length=8)
    self.assertAlmostEqual(ratio, 0.375, places=6)

  def test_full_topk_equals_dense(self):
    # K=T: ratio=1, mask = 1 - 0.5 = 0.5 (same as causal)
    ratio = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=8, max_target_length=8)
    self.assertAlmostEqual(ratio, 0.5, places=6)

  def test_small_topk(self):
    # K=1, T=100: ratio=0.01, mask ≈ 0.01 - 0.5*0.0001 ≈ 0.00995
    ratio = maxtext_utils.calculate_indexer_mask_ratio(indexer_topk=1, max_target_length=100)
    self.assertAlmostEqual(ratio, 0.01 - 0.5 * (0.01**2), places=8)
num_activations + return cfg + + def test_total_flops_positive(self): + result = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(self._make_cfg(), mlp_dim=2048) + self.assertGreater(result, 0) + + def test_scales_with_mlp_dim(self): + cfg = self._make_cfg() + small = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=1024) + large = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=4096) + self.assertGreater(large, small) + + def test_single_activation(self): + """With one activation, ffn1 uses 1x mlp_dim.""" + cfg = self._make_cfg(num_activations=1) + result = maxtext_utils.calculate_ffn_mamtul_tflops_per_device(cfg, mlp_dim=2048) + expected_ffn1 = 2 * 1 * 64 * 2048 * 512 * 1 + expected_ffn2 = 2 * 1 * 64 * 2048 * 512 + self.assertEqual(result, expected_ffn1 + expected_ffn2) + + +class TestGetDenseMoeLayers(unittest.TestCase): + """Tests for get_dense_moe_layers.""" + + def _make_cfg(self, decoder_block, num_decoder_layers=32, first_num_dense_layers=3, interleave_moe_layer_step=4): + cfg = MagicMock() + cfg.decoder_block = decoder_block + cfg.num_decoder_layers = num_decoder_layers + cfg.first_num_dense_layers = first_num_dense_layers + cfg.interleave_moe_layer_step = interleave_moe_layer_step + return cfg + + def test_deepseek_block(self): + cfg = self._make_cfg(DecoderBlockType.DEEPSEEK, num_decoder_layers=32, first_num_dense_layers=3) + dense, moe = maxtext_utils.get_dense_moe_layers(cfg) + self.assertEqual(dense, 3) + self.assertEqual(moe, 29) + + def test_llama4_block(self): + cfg = self._make_cfg(DecoderBlockType.LLAMA4, num_decoder_layers=16, interleave_moe_layer_step=4) + dense, moe = maxtext_utils.get_dense_moe_layers(cfg) + self.assertEqual(moe, 4) # 16 // 4 + self.assertEqual(dense, 12) # 16 - 4 + + def test_qwen3_next_block(self): + cfg = self._make_cfg(DecoderBlockType.QWEN3_NEXT, num_decoder_layers=8) + dense, moe = maxtext_utils.get_dense_moe_layers(cfg) + self.assertEqual(dense, 0) + self.assertEqual(moe, 8) + + 
def test_unsupported_block_raises(self): + cfg = self._make_cfg(DecoderBlockType.DEFAULT) + with self.assertRaises(ValueError): + maxtext_utils.get_dense_moe_layers(cfg) + + +class TestCalculatePrefillTflops(unittest.TestCase): + """Tests for calculate_prefill_tflops_per_device.""" + + def _make_cfg(self, num_query_heads=8, num_decoder_layers=2, head_dim=64): + cfg = MagicMock() + cfg.num_query_heads = num_query_heads + cfg.num_decoder_layers = num_decoder_layers + cfg.head_dim = head_dim + return cfg + + def test_returns_three_positive_values(self): + cfg = self._make_cfg() + total, lw, attn = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=1_000_000, prefill_length=128, config=cfg, log=False + ) + self.assertGreater(total, 0) + self.assertGreater(lw, 0) + self.assertGreater(attn, 0) + + def test_total_is_sum_of_parts(self): + cfg = self._make_cfg() + total, lw, attn = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=500_000, prefill_length=64, config=cfg, log=False + ) + self.assertAlmostEqual(total, lw + attn, places=10) + + def test_scales_with_prefill_length(self): + cfg = self._make_cfg() + _, _, attn_short = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=1_000_000, prefill_length=64, config=cfg, log=False + ) + _, _, attn_long = maxtext_utils.calculate_prefill_tflops_per_device( + num_model_parameters=1_000_000, prefill_length=128, config=cfg, log=False + ) + # Attention scales quadratically with prefill length + self.assertGreater(attn_long, attn_short * 3) + + +class TestSetupTrainingState(unittest.TestCase): + """Tests for setup_training_state (thin wrapper over setup_initial_state).""" + + def setUp(self): + extra_args = get_decoupled_parallelism_overrides() + self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args) + if self.config.pure_nnx: + self.skipTest("Pure NNX path not covered by this test.") + devices_array = 
maxtext_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + quant = quantizations.configure_quantization(self.config) + self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + + def test_setup_training_state_returns_train_state(self): + rng = jax.random.PRNGKey(0) + tx = optax.adam(learning_rate=0.001) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, tx, self.config, True, rng) + state, _, _, _ = maxtext_utils.setup_training_state(None, self.config, self.mesh, None, init_state_fn) + self.assertEqual(state.tx, tx) + self.assertNotEqual(state.opt_state, {}) + + +class TestGetLogicalAnnotations(unittest.TestCase): + """Tests for get_logical_annotations.""" + + def setUp(self): + extra_args = get_decoupled_parallelism_overrides() + self.config = pyconfig.initialize([None, get_test_config_path()], enable_checkpointing=False, **extra_args) + if self.config.pure_nnx: + self.skipTest("Pure NNX path not covered by this test.") devices_array = maxtext_utils.create_device_mesh(self.config) self.mesh = Mesh(devices_array, self.config.mesh_axes) quant = quantizations.configure_quantization(self.config) @@ -928,18 +1355,193 @@ def setUp(self): self.rng = jax.random.PRNGKey(0) self.tx = optax.adam(learning_rate=0.001) - def test_get_abstract_state(self): - """Tests that get_abstract_state returns abstract arrays.""" - # get_abstract_state returns a tuple, the first element is the abstract state. 
- abstract_state, _, _ = maxtext_utils.get_abstract_state(self.model, self.tx, self.config, self.rng, self.mesh, None) + def test_returns_partition_spec_tree(self): + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, self.tx, self.config, True, self.rng) + annotations = maxtext_utils.get_logical_annotations(self.config, self.mesh, init_state_fn) + # Result should be a pytree with PartitionSpec leaves + leaves = jax.tree_util.tree_leaves(annotations) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertIsInstance(leaf, PartitionSpec) + + +class TestSaveQuantizedCheckpoint(unittest.TestCase): + """Tests for save_quantized_checkpoint_if_configured.""" + + def test_raises_when_no_quantization(self): + cfg = MagicMock() + cfg.quantization = "" + with self.assertRaises(AssertionError): + maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={}) + + @unittest.mock.patch("maxtext.utils.maxtext_utils.checkpointing") + def test_skips_save_when_path_empty(self, mock_ckpt): + cfg = MagicMock() + cfg.quantization = "int8" + cfg.save_quantized_params_path = "" + maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={}) + mock_ckpt.save_params_to_path.assert_not_called() + + @unittest.mock.patch("maxtext.utils.maxtext_utils.checkpointing") + def test_calls_save_when_path_set(self, mock_ckpt): + cfg = MagicMock() + cfg.quantization = "int8" + cfg.save_quantized_params_path = "/tmp/quantized" + cfg.checkpoint_storage_use_ocdbt = True + cfg.checkpoint_storage_use_zarr3 = True + maxtext_utils.save_quantized_checkpoint_if_configured(cfg, params={"w": jnp.ones((2,))}) + mock_ckpt.save_params_to_path.assert_called_once() - # Check that params are abstract - param_leaves = jax.tree_util.tree_leaves(abstract_state.params) - self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in param_leaves)) - # Check that opt_state is abstract - opt_state_leaves = jax.tree_util.tree_leaves(abstract_state.opt_state) 
- self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in opt_state_leaves)) +class TestAddConfigToSummaryWriter(unittest.TestCase): + """Tests for add_config_to_summary_writer.""" + + def test_calls_add_text_for_each_key(self): + cfg = MagicMock() + cfg.get_keys.return_value = {"learning_rate": 0.001, "steps": 100} + mock_writer = MagicMock() + + with unittest.mock.patch("maxtext.utils.max_utils.add_text_to_summary_writer") as mock_add: + maxtext_utils.add_config_to_summary_writer(cfg, mock_writer) + # Should have been called once per config key (process_index==0 in tests) + if jax.process_index() == 0: + self.assertEqual(mock_add.call_count, 2) + + +class TestMaybeDumpJaxpr(unittest.TestCase): + """Tests for maybe_dump_jaxpr.""" + + def test_early_return_when_disabled(self): + cfg = MagicMock() + cfg.dump_jaxpr = False + # Should return immediately without calling any JAX tracing (no exception raised) + maxtext_utils.maybe_dump_jaxpr(cfg, p_train_step=None, train_step_inputs=None) + + +class TestPrintShardingsParams(unittest.TestCase): + """Tests for print_shardings_params — normalization branches.""" + + def setUp(self): + """Build a minimal mesh and sharded param for testing.""" + self.mesh = Mesh(np.array(jax.devices()), ("data",)) + + def _make_simple_params(self): + """Return (params, param_sharding, logical) without a .params attribute.""" + params = {"w": jnp.ones((4,))} + param_sharding = {"w": NamedSharding(self.mesh, PartitionSpec(None))} + logical = {"w": PartitionSpec(None)} + return params, param_sharding, logical + + def test_runs_without_error_dict_inputs(self): + """print_shardings_params should not raise with plain dict inputs.""" + params, param_sharding, logical = self._make_simple_params() + # Should complete without raising + maxtext_utils.print_shardings_params(params, param_sharding, logical) + + def test_runs_without_logical_annotations(self): + """logical_annotations=None should be handled (no logical column).""" + params, 
param_sharding, _ = self._make_simple_params() + maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) + + +class TestNNXAbstractState(unittest.TestCase): + """Test the get_abstract_state_nnx func.""" + + @dataclass + class MockConfig: + init_weights_seed: int = 42 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + param_scan_axis: int = 0 + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + + class MockTrainState(nnx.Module): + """Simulates a TrainState with params and optimizer state.""" + + def __init__(self, rngs: nnx.Rngs): + # Model parameters + device_num = len(jax.local_devices()) + self.params = nnx.Linear( + 2, 4, kernel_init=nnx.with_partitioning(nnx.initializers.ones, sharding=("model",)), rngs=rngs + ) + # Simulated optimizer state + self.optimizer = nnx.Variable(jnp.zeros((device_num,)), sharding=("model",)) + + def setUp(self): + # Create a real 1D mesh on local devices + devices = jax.local_devices() + self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) + self.config = self.MockConfig() + + def nnx_init_trainstate_wrapper(self): + """Wrapper to initialize the mock NNX model.""" + rngs = maxtext_utils_nnx.create_nnx_rngs(self.config) + return self.MockTrainState(rngs) + + def test_basic_abstraction(self): + """Verifies the basic return structure and partition spec extraction.""" + abstract_state, annotations, shardings = maxtext_utils.get_abstract_state_nnx( + self.config, self.mesh, self.nnx_init_trainstate_wrapper + ) + + # Check return types + self.assertIsInstance(abstract_state, nnx.State) + self.assertIsInstance(annotations, nnx.State) + self.assertIsInstance(shardings, nnx.State) + + # Verify PartitionSpec was extracted correctly from the mock model's annotations + # Path: params -> kernel -> spec + self.assertEqual( + 
annotations.params.kernel.get_value(), + PartitionSpec( + "model", + ), + ) + + def test_shard_optimizer_over_data(self): + """Verifies that 'data' is added to optimizer sharding using the real utility.""" + self.config.shard_optimizer_over_data = True + + _, annotations, _ = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Original Pspec for optimizer was PartitionSpec(None). + # add_data_to_sharding should find that dim 0 is compatible with mesh 'data' + # and update it to PartitionSpec(('data',)). + opt_spec = annotations.optimizer.get_value() + + # Verify 'data' is now in the spec + self.assertEqual(opt_spec, PartitionSpec(("data", "model"))) + + def test_optimizer_host_offload(self): + """Verifies that optimizer memory is moved to host when configured.""" + self.config.optimizer_memory_host_offload = True + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Optimizer state should be pinned to host + opt_sharding = shardings.optimizer.get_value() + self.assertEqual(opt_sharding.memory_kind, "pinned_host") + + # Params should still be on default memory (usually device) + param_sharding = shardings.params.kernel.get_value() + self.assertNotEqual(param_sharding.memory_kind, "pinned_host") + + def test_parameter_host_offload(self): + """Verifies that parameter memory is moved to host when configured.""" + self.config.parameter_memory_host_offload = True + self.config.param_scan_axis = 0 + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Parameters should be pinned to host + param_sharding = shardings.params.kernel.get_value() + self.assertEqual(param_sharding.memory_kind, "pinned_host") + + def test_invalid_init_fn(self): + """Ensures function raises error if no init function is provided.""" + with self.assertRaises(AssertionError): + 
maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) if __name__ == "__main__": diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py index 7f8c784176..ba4cb8817c 100644 --- a/tests/unit/model_creation_utils_test.py +++ b/tests/unit/model_creation_utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2023-2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,18 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for model_creation_utils.""" +"""Unit tests for model_creation_utils.py.""" import dataclasses +import sys import unittest +from unittest.mock import MagicMock, patch import jax import jax.numpy as jnp - +import flax.linen as nn +from flax import nnx +from jax.sharding import Mesh from orbax import checkpoint as ocp -# Import the private helpers under test. +from maxtext.configs import pyconfig +from maxtext.common.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL +from maxtext.models import models +from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from maxtext.utils.model_creation_utils import _fix_restore_args_for_shape_mismatch +from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides # --------------------------------------------------------------------------- @@ -41,10 +50,8 @@ def _is_fake_meta(x): # Monkey-patch the module-level helper so our fake metadata is recognised. 
-import maxtext.utils.model_creation_utils as _mcu - -_orig_is_orbax = _mcu._is_orbax_array_metadata # pylint: disable=protected-access -_mcu._is_orbax_array_metadata = _is_fake_meta # pylint: disable=protected-access +_orig_is_orbax = model_creation_utils._is_orbax_array_metadata # pylint: disable=protected-access +model_creation_utils._is_orbax_array_metadata = _is_fake_meta # pylint: disable=protected-access def _make_restore_arg(global_shape): @@ -59,6 +66,34 @@ def _make_restore_arg(global_shape): ) +def _make_config(**kwargs): + """Returns a minimal pyconfig suitable for model-creation tests.""" + extra = get_decoupled_parallelism_overrides() + defaults = { + "per_device_batch_size": 1.0, + "run_name": "test", + "enable_checkpointing": False, + "base_num_decoder_layers": 2, + "attention": "dot_product", + "max_target_length": 16, + "base_emb_dim": 256, + "base_num_query_heads": 2, + "base_num_kv_heads": 2, + "max_prefill_predict_length": 4, + } + defaults.update(kwargs) + return pyconfig.initialize( + [sys.argv[0], get_test_config_path()], + **defaults, + **extra, + ) + + +def _make_mesh(config): + devices_array = maxtext_utils.create_device_mesh(config) + return Mesh(devices_array, config.mesh_axes) + + class FixRestoreArgsRankGuardTest(unittest.TestCase): """_fix_restore_args_for_shape_mismatch must not touch args when stored rank != model rank.""" @@ -106,5 +141,258 @@ def test_scanned_both_same_rank_shape_mismatch_is_modified(self): self.assertIsNone(arg.global_shape) +class TestGetTransformerModel(unittest.TestCase): + """Tests for get_transformer_model().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_returns_linen_module_when_rngs_is_none(self): + """Without rngs, should return a Linen nn.Module.""" + model = model_creation_utils.get_transformer_model(self.config, self.mesh, quant=None, rngs=None) + self.assertIsInstance(model, nn.Module) + + def 
test_returns_nnx_module_when_rngs_provided(self): + """With rngs, should return an NNX nnx.Module.""" + model = nnx.eval_shape( + lambda: model_creation_utils.get_transformer_model( + self.config, self.mesh, quant=None, rngs=nnx.Rngs(params=0, dropout=1, aqt=2) + ) + ) + self.assertIsInstance(model, nnx.Module) + + def test_respects_model_mode_prefill(self): + """Linen model created with MODEL_MODE_PREFILL should differ from train mode.""" + linen_train = model_creation_utils.get_transformer_model( + self.config, self.mesh, quant=None, model_mode=MODEL_MODE_TRAIN, rngs=None + ) + linen_prefill = model_creation_utils.get_transformer_model( + self.config, self.mesh, quant=None, model_mode=MODEL_MODE_PREFILL, rngs=None + ) + # Both are still nn.Module instances + self.assertIsInstance(linen_train, nn.Module) + self.assertIsInstance(linen_prefill, nn.Module) + + +class TestCreateModel(unittest.TestCase): + """Tests for create_model().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_returns_linen_model_without_rngs(self): + model = model_creation_utils.create_model(self.config, self.mesh) + self.assertIsInstance(model, nn.Module) + + def test_returns_nnx_model_with_rngs(self): + model = nnx.eval_shape( + lambda: model_creation_utils.create_model(self.config, self.mesh, rngs=nnx.Rngs(params=0, dropout=1, aqt=2)) + ) + self.assertIsInstance(model, nnx.Module) + + def test_model_mode_train_default(self): + """Default model_mode is MODEL_MODE_TRAIN.""" + model = model_creation_utils.create_model(self.config, self.mesh) + self.assertIsInstance(model, nn.Module) + + +class TestFromConfig(unittest.TestCase): + """Tests for from_config().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_linen_path_rngs_none(self): + """from_config with rngs=None should return a Linen nn.Module.""" + model = model_creation_utils.from_config(self.config, mesh=self.mesh, rngs=None) + 
self.assertIsInstance(model, nn.Module) + + def test_nnx_path_with_rngs(self): + """from_config with rngs provided should return an NNX nnx.Module.""" + model = nnx.eval_shape( + lambda: model_creation_utils.from_config(self.config, mesh=self.mesh, rngs=nnx.Rngs(params=0, dropout=1, aqt=2)) + ) + self.assertIsInstance(model, nnx.Module) + + def test_mesh_created_from_devices_when_none(self): + """from_config should work when mesh is None (creates mesh internally).""" + model = model_creation_utils.from_config(self.config, devices=None, mesh=None, rngs=None) + self.assertIsInstance(model, nn.Module) + + def test_model_mode_is_forwarded(self): + """from_config should accept and forward model_mode.""" + model = model_creation_utils.from_config(self.config, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL, rngs=None) + self.assertIsInstance(model, nn.Module) + + def test_explicit_shard_mode_creates_mesh_with_explicit_axis_types(self): + """from_config with shard_mode=explicit should create mesh using AxisType.Explicit.""" + cfg = _make_config(shard_mode="explicit") + # Should not raise; mesh is built with AxisType.Explicit for each axis + model = model_creation_utils.from_config(cfg, mesh=None, rngs=None) + self.assertIsInstance(model, nn.Module) + + +class TestCreateNNXAbstractModel(unittest.TestCase): + """Tests for create_nnx_abstract_model().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_returns_tuple_of_callable_and_module(self): + create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) + self.assertTrue(callable(create_fn)) + self.assertIsInstance(abstract_model, nnx.Module) + + def test_abstract_model_has_abstract_arrays(self): + """Abstract model leaves should be ShapeDtypeStruct, not concrete arrays.""" + _, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) + _, state = nnx.split(abstract_model) + leaves = 
jax.tree.leaves(state) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + # In abstract state, values are nnx.Variable wrapping abstract shapes/ShapeDtypeStruct + # Concrete jax.Array would have a .devices() method; abstract ones should not be Arrays + self.assertNotIsInstance(leaf, jax.Array) + + def test_create_fn_produces_concrete_model(self): + """The returned create_fn should produce a real (concrete) NNX Module.""" + create_fn, _ = model_creation_utils.create_nnx_abstract_model(self.config, mesh=self.mesh) + with self.mesh: + concrete = create_fn() + self.assertIsInstance(concrete, nnx.Module) + leaves = jax.tree.leaves(nnx.state(concrete)) + for leaf in leaves: + self.assertIsInstance(leaf, jax.Array) + + def test_works_without_explicit_mesh(self): + """create_nnx_abstract_model should work when mesh=None (from_config creates mesh).""" + create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model(self.config, mesh=None) + self.assertTrue(callable(create_fn)) + self.assertIsInstance(abstract_model, nnx.Module) + + def test_explicit_rng_key_is_used(self): + """Passing a rng_key should not raise and returns valid abstract model.""" + rng_key = jax.random.PRNGKey(42) + create_fn, abstract_model = model_creation_utils.create_nnx_abstract_model( + self.config, mesh=self.mesh, rng_key=rng_key + ) + self.assertTrue(callable(create_fn)) + self.assertIsInstance(abstract_model, nnx.Module) + + def test_prefill_model_mode(self): + """create_nnx_abstract_model should accept MODEL_MODE_PREFILL.""" + _, abstract_model = model_creation_utils.create_nnx_abstract_model( + self.config, mesh=self.mesh, model_mode=MODEL_MODE_PREFILL + ) + self.assertIsInstance(abstract_model, nnx.Module) + + +class TestCreateNnxModel(unittest.TestCase): + """Tests for create_nnx_model().""" + + def setUp(self): + self.config = _make_config() + self.mesh = _make_mesh(self.config) + + def test_no_checkpoint_returns_model_and_mesh(self): + """Without load_parameters_path, 
should return (model, mesh) cleanly.""" + model, mesh = model_creation_utils.create_nnx_model(self.config, self.mesh) + self.assertIsInstance(model, models.Transformer) + self.assertIsInstance(mesh, Mesh) + + def test_mesh_none_uses_abstract_model_mesh(self): + """When mesh=None is passed, the function resolves it from the abstract model.""" + model, mesh = model_creation_utils.create_nnx_model(self.config, mesh=None) + self.assertIsInstance(model, models.Transformer) + self.assertIsInstance(mesh, Mesh) + + def test_explicit_rng_key(self): + """An explicit rng_key should be accepted without error.""" + rng_key = jax.random.PRNGKey(99) + model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, rng_key=rng_key) + self.assertIsInstance(model, models.Transformer) + + def test_inference_mode_disables_dropout_rng(self): + """MODEL_MODE_PREFILL should create rngs without a dropout key.""" + model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, model_mode=MODEL_MODE_PREFILL) + self.assertIsInstance(model, models.Transformer) + + def test_debug_sharding_flag(self): + """debug_sharding=True should execute the sharding-print path without error.""" + cfg = _make_config(debug_sharding=True) + model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + + # ---- checkpoint loading: mocked paths ---- + + def _make_linen_metadata_mock(self): + """Mock ocp metadata that looks like a Linen checkpoint.""" + meta = MagicMock() + meta.item_metadata.tree.keys.return_value = ["params"] + meta.item_metadata.tree.get.return_value = {"params": {}} + return meta + + def _make_nnx_metadata_mock(self): + """Mock ocp metadata that looks like an NNX checkpoint.""" + meta = MagicMock() + meta.item_metadata.tree.keys.return_value = ["decoder"] + meta.item_metadata.tree.get.return_value = {} + return meta + + @patch("maxtext.utils.model_creation_utils.ocp") + def test_load_nnx_checkpoint(self, mock_ocp): + 
"""NNX-format checkpoint: restored values are wrapped under a 'value' key.""" + # Echo back the `item` argument passed by create_nnx_model to ckptr.restore. + # For NNX checkpoints, item IS already {leaf: {"value": array}, ...}, so + # returning it directly gives a correctly-structured restored dict that + # matches the model's own state — regardless of the exact leaf count. + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock() + mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} + mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs + + cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/nnx_ckpt") + model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + + @patch("maxtext.utils.model_creation_utils.ocp") + def test_load_linen_checkpoint(self, mock_ocp): + """Linen-format checkpoint: restored values are nested under 'params'/'params'.""" + # Echo back the `item` argument passed by create_nnx_model to ckptr.restore. + # For Linen checkpoints, item IS already {"params": {"params": arrays}}, so + # returning it directly gives a correctly-structured restored dict that + # matches the model's own state — regardless of the exact leaf count. 
+ mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = self._make_linen_metadata_mock() + mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} + mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs + + cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/linen_ckpt") + model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + + @patch("maxtext.utils.model_creation_utils.ocp") + def test_checkpoint_load_error_raises_value_error(self, mock_ocp): + """Any exception during checkpoint loading should be re-raised as ValueError.""" + mock_ckptr = MagicMock() + mock_ckptr.metadata.side_effect = RuntimeError("disk on fire") + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + + cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/bad_ckpt") + with self.assertRaises(ValueError): + model_creation_utils.create_nnx_model(cfg, self.mesh) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index c3a43970ba..504d8fab4b 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -15,17 +15,18 @@ """ Unit tests for all optimizers. 
""" import re import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import jax +import jax.numpy as jnp import pytest from absl.testing import parameterized +from flax import nnx from optax.contrib import MuonDimensionNumbers as mdn from maxtext.configs import pyconfig from maxtext.optimizers import optimizers -from maxtext.utils import maxtext_utils -from maxtext.utils.muon_utils import get_model_mdn +from maxtext.utils import maxtext_utils, muon_utils from tests.utils.test_helpers import get_test_config_path from typing import NamedTuple @@ -47,6 +48,7 @@ DEEPSEEK2_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -55,6 +57,7 @@ }, **_DEEPSEEK2_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -71,8 +74,6 @@ }, **_DEEPSEEK2_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -97,6 +98,7 @@ DEEPSEEK3_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -105,6 +107,7 @@ }, **_DEEPSEEK3_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -121,8 +124,6 @@ }, **_DEEPSEEK3_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -241,7 +242,7 @@ def test_model_integration(self, model_name, expected_output): Initializes the specified MaxText model and asserts that the generated Muon dimension numbers match the hardcoded reference. 
""" - actual_output = get_model_mdn(model_name, scan_layers=True) + actual_output = muon_utils.get_model_mdn(model_name, scan_layers=True) self.assertEqual(actual_output, expected_output) @@ -428,5 +429,105 @@ def learning_rate_schedule(step): self.assertFalse(jax.numpy.all(updates["layer1"]["kernel"] == 0)) +class TestMuonLogic(unittest.TestCase): + """Tests the granular path transformation functions.""" + + def test_is_path_contain_any(self): + # pylint: disable=protected-access + self.assertTrue(muon_utils._is_path_contain_any(("a", "b"), ("x", "a", "z"))) + self.assertFalse(muon_utils._is_path_contain_any(("a", "b"), ("x", "y", "z"))) + + def test_transform_logic_exclusions(self): + self.assertIsNone(muon_utils.transform_logic(("layer_0", "bias"))) + self.assertIsNone(muon_utils.transform_logic(("layer_0", "scale"))) + self.assertIsNone(muon_utils.transform_logic(("embedding", "kernel"))) + + def test_transform_logic_moe(self): + path = ("layers_0", "MoeBlock_0", "wi_0") + result = muon_utils.transform_logic(path) + self.assertEqual(result.reduction_axis, (-2,)) + self.assertEqual(result.output_axis, (-1,)) + + def test_transform_logic_attention(self): + path_out = ("layers_0", "self_attention", "out", "kernel") + self.assertEqual(muon_utils.transform_logic(path_out), mdn((0, -2), (-1,))) + + path_q = ("layers_0", "self_attention", "query", "kernel") + self.assertEqual(muon_utils.transform_logic(path_q), mdn((0,), (-2, -1))) + + def test_get_transform_tree(self): + fake_tree = {"params": {"layer_0": {"kernel": "leaf", "bias": "leaf"}, "MoeBlock_0": {"wi_0": "leaf"}}} + result = muon_utils.get_transform_tree(fake_tree) + self.assertEqual(result["params"]["layer_0"]["kernel"], mdn((0,), (-1,))) + self.assertIsNone(result["params"]["layer_0"]["bias"]) + + def test_get_muon_weight_dimension_numbers_nnx(self): + """Verifies dimension extraction for stateful NNX modules.""" + + class MockNNXModel(nnx.Module): + """Mock NNX Module.""" + + def __init__(self, rngs: 
nnx.Rngs): + # 1. Standard layer + self.layer1 = nnx.Linear(2, 4, rngs=rngs) + + # 2. MoE specific naming to trigger transform logic. + # The logic expects "MoeBlock_0" AND "wi_0"/"wi_1"/"wo" in the path. + # We nest the linear layer to create the path: ('MoeBlock_0', 'wi_0', 'kernel') + self.MoeBlock_0 = nnx.Module() + self.MoeBlock_0.wi_0 = nnx.Linear(4, 2, rngs=rngs) + + # 3. Exclusion case (scalar/scale) + self.scale = nnx.Param(jnp.ones((1,))) + + # Use eval_shape to create an abstract version of the model. + model = nnx.eval_shape(lambda: MockNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + + # Extract dimension numbers using the NNX path in muon_utils + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Verify standard weight path: ('layer1', 'kernel') -> default (0,) + self.assertEqual(result.layer1.kernel.value, mdn((0,), (-1,))) + + # Verify MoE weight path: ('MoeBlock_0', 'wi_0', 'kernel') -> (-2,) + self.assertEqual(result.MoeBlock_0.wi_0.kernel.value, mdn((-2,), (-1,))) + + # Verify exclusion (scalar/scale) + self.assertIsNone(result.scale.value) + + def test_verbose_output_nnx(self): + """Covers lines 128 and 135-154: _print_structure_debug via verbose=True with NNX model.""" + + class SimpleNNXModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + model = nnx.eval_shape(lambda: SimpleNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + muon_utils.get_muon_weight_dimension_numbers(model, config, verbose=True) + + def test_nnx_deepseek_attention_logic(self): + """Simulates a DeepSeek-like attention structure in NNX.""" + + class DeepSeekAttention(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.self_attention = nnx.Module() + self.self_attention.query = nnx.Linear(8, 8, rngs=rngs) + self.self_attention.out = nnx.Linear(8, 8, rngs=rngs) + + # Use eval_shape to create an abstract version of the model. 
+ model = nnx.eval_shape(lambda: DeepSeekAttention(nnx.Rngs(0))) + config = MagicMock() + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Check attention query: [0] -> [-2, -1] + self.assertEqual(result.self_attention.query.kernel.value, mdn((0,), (-2, -1))) + # Check attention out: [0, -2] -> [-1] + self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/pipeline_parallelism_test.py b/tests/unit/pipeline_parallelism_test.py index b2582d822c..2506adb35a 100644 --- a/tests/unit/pipeline_parallelism_test.py +++ b/tests/unit/pipeline_parallelism_test.py @@ -118,9 +118,7 @@ def get_inputs(batch_size, sequence, features): single_pipeline_stage = simple_layer.SimpleDecoderLayerToLinen( config=config, mesh=mesh, model_mode=model_mode, rngs=rngs ) - my_pipeline = pipeline.create_pipeline( - config=config, layers=single_pipeline_stage, mesh=mesh - ) + my_pipeline = pipeline.create_pipeline(config=config, layers=single_pipeline_stage, mesh=mesh) init_pipeline_params = my_pipeline.init( jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode ) @@ -351,33 +349,33 @@ def test_full_train_circular(self): def test_full_train_circular_pipeline_ag_per_repeat(self): # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), # 8 microbatches and using pipeline ag per repeat - train_main([ - None, - get_test_config_path(), - f"base_output_directory={self.base_output_directory}", - "run_name=runner_pipeline_parallelism_test", - f"dataset_path={self.dataset_path}", - "base_emb_dim=28", - "base_num_query_heads=4", - "base_num_kv_heads=4", - "base_mlp_dim=32", - "base_num_decoder_layers=32", - "head_dim=128", - "per_device_batch_size=2", - "max_target_length=1024", - "vocab_size=32", - "dataset_type=synthetic", - "steps=3", - "enable_checkpointing=False", - "enable_goodput_recording=False", - 
"ici_pipeline_parallelism=2", - "num_layers_per_pipeline_stage=1", - "num_pipeline_microbatches=4", - "pipeline_fsdp_ag_per_repeat=True", - ( - rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}" - ), - ]) + train_main( + [ + None, + get_test_config_path(), + f"base_output_directory={self.base_output_directory}", + "run_name=runner_pipeline_parallelism_test", + f"dataset_path={self.dataset_path}", + "base_emb_dim=28", + "base_num_query_heads=4", + "base_num_kv_heads=4", + "base_mlp_dim=32", + "base_num_decoder_layers=32", + "head_dim=128", + "per_device_batch_size=2", + "max_target_length=1024", + "vocab_size=32", + "dataset_type=synthetic", + "steps=3", + "enable_checkpointing=False", + "enable_goodput_recording=False", + "ici_pipeline_parallelism=2", + "num_layers_per_pipeline_stage=1", + "num_pipeline_microbatches=4", + "pipeline_fsdp_ag_per_repeat=True", + (rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}"), + ] + ) @pytest.mark.tpu_only def test_delay_activation_forwarding_same_output_and_grad(self): @@ -492,6 +490,9 @@ def test_full_train_fp8(self): "quantization=fp8", "scan_layers_per_stage=False", "attention=dot_product", + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] _adapt_parallelism(args, pipeline_stages=4) train_main(args) @@ -525,6 +526,9 @@ def test_full_train_nanoo_fp8(self): "quantization=nanoo_fp8", "scan_layers_per_stage=False", "attention=dot_product", + "pure_nnx=False", + "enable_nnx=False", + "pure_nnx_decoder=False", ] _adapt_parallelism(args, pipeline_stages=4) train_main(args) diff --git a/tests/unit/sharding_compare_test.py b/tests/unit/sharding_compare_test.py index 2cd696f241..fe3b7c6386 100644 --- a/tests/unit/sharding_compare_test.py +++ b/tests/unit/sharding_compare_test.py @@ -14,13 +14,16 @@ """Compare expected sharding of models with actual sharding of models.""" +import functools import hashlib import json import os import jax 
import jax.numpy as jnp +from flax import nnx from maxtext.configs import pyconfig -from maxtext.utils import maxtext_utils +from maxtext.layers.train_state_nnx import TrainStateNNX +from maxtext.utils import maxtext_utils, maxtext_utils_nnx, model_creation_utils from maxtext.utils.sharding import clear_input_shardings_dump # import optax @@ -221,20 +224,33 @@ def abstract_state_and_shardings(request): topology_mesh = get_topology_mesh(config) quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh=topology_mesh, quant=quant) - learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) - # tx = optax.adam(learning_rate=learning_rate_schedule) tx = optimizers.get_optimizer(config, learning_rate_schedule) - rng = jax.random.PRNGKey(0) + + if config.pure_nnx: + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, topology_mesh) + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return TrainStateNNX(nnx_model, optimizer) + + init_state_fn = create_train_state_fn + else: + model = Transformer(config, mesh=topology_mesh, quant=quant) + rng = jax.random.PRNGKey(0) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) # Get abstract state and physical shardings from maxtext_utils abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - model, tx, config, rng, topology_mesh, is_training=True + config, topology_mesh, init_state_fn, is_training=True ) - # Get logical shardings from maxtext_utils - logical_shardings = maxtext_utils.get_logical_annotations(model, tx, config, rng, topology_mesh, is_training=True) + # Get logical shardings + if config.pure_nnx: + logical_shardings = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + else: + logical_shardings = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) return 
model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings @@ -252,11 +268,23 @@ def test_get_abstract_state_sharding(self, abstract_state_and_shardings): # pyl abstract_state_and_shardings ) - assert hasattr(abstract_state, "params") - assert hasattr(abstract_state, "opt_state") - param_leaf = jax.tree_util.tree_leaves(abstract_state.params)[0] - assert isinstance(param_leaf, jax.ShapeDtypeStruct) - assert param_leaf.dtype == jnp.float32 + if hasattr(abstract_state, "params"): # Linen TrainState + assert hasattr(abstract_state, "opt_state") + param_leaf = jax.tree_util.tree_leaves(abstract_state.params)[0] + assert isinstance(param_leaf, jax.ShapeDtypeStruct) + assert param_leaf.dtype == jnp.float32 + else: # NNX nnx.State + assert hasattr(abstract_state, "model") + assert hasattr(abstract_state, "optimizer") + # Filter to floating-point leaves only: abstract_state.model also contains + # RNG state variables (uint32 / key dtype) which are not weight parameters. + float_leaves = [ + l + for l in jax.tree_util.tree_leaves(abstract_state.model) + if isinstance(l, jax.ShapeDtypeStruct) and jnp.issubdtype(l.dtype, jnp.floating) + ] + assert len(float_leaves) > 0 + assert all(l.dtype == jnp.float32 for l in float_leaves) root_dir = "tests/utils/sharding_info" # Or your target directory base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}") diff --git a/tests/unit/state_dtypes_test.py b/tests/unit/state_dtypes_test.py index 77e166193a..a251b0865d 100644 --- a/tests/unit/state_dtypes_test.py +++ b/tests/unit/state_dtypes_test.py @@ -13,17 +13,20 @@ # limitations under the License. 
""" Test that all weights are expected dtype (default float32) """ +from functools import partial import unittest import jax import jax.numpy as jnp +from flax import nnx from jax.sharding import Mesh from maxtext.configs import pyconfig from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations +from maxtext.layers.train_state_nnx import TrainStateNNX from maxtext.models import models from maxtext.optimizers import optimizers -from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils, model_creation_utils from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides Transformer = models.transformer_as_linen @@ -34,27 +37,42 @@ class StateDtypes(unittest.TestCase): def get_state(self, argv): """Gets model state including weights and optimizer state""" - # Conditionally set ici_fsdp_parallelism to match device count in decoupled mode argv = list(argv) + get_decoupled_parallelism_overrides(as_argv=True) - - # Setup necessary inputs to build a model state config = pyconfig.initialize(argv) quant = quantizations.configure_quantization(config) devices_array = maxtext_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - model = Transformer(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, example_rng, mesh) + if config.pure_nnx: + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return TrainStateNNX(nnx_model, optimizer) + + return nnx.eval_shape(create_train_state_fn) + + model = Transformer(config, mesh, 
quant=quant, model_mode=MODEL_MODE_TRAIN) + _, example_rng = jax.random.split(jax.random.PRNGKey(0), 2) + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) + abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) return abstract_state def get_weights(self, argv): - return self.get_state(argv).params + state = self.get_state(argv) + if isinstance(state, TrainStateNNX): + _, param_state, _ = nnx.split(state, nnx.Param, ...) + return param_state + return state.params def get_mu(self, argv): - return self.get_state(argv).opt_state[0].mu + state = self.get_state(argv) + if isinstance(state, TrainStateNNX): + return state.optimizer.opt_state[0].mu + return state.opt_state[0].mu def assert_pytree_is_dtype(self, weights, expected_dtype): jax.tree_util.tree_map_with_path(lambda x, y: self.assertEqual(y.dtype, expected_dtype), weights) diff --git a/tests/unit/tiling_test.py b/tests/unit/tiling_test.py index 58b688634d..9767384f6b 100644 --- a/tests/unit/tiling_test.py +++ b/tests/unit/tiling_test.py @@ -209,6 +209,8 @@ def test_vocab_tiling_gradient_with_z_loss(self): num_vocab_tiling=1, z_loss_multiplier=1e-4, # Enable z-loss ) + if cfg_non_tiling.enable_nnx: + self.skipTest("Vocab tiling is not supported with NNX yet.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -275,6 +277,8 @@ def test_vocab_tiling_gradient_non_tied_embedding(self): matmul_precision="high", num_vocab_tiling=1, ) + if cfg_non_tiling.enable_nnx: + self.skipTest("Vocab tiling is not supported with NNX yet.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -339,6 +343,8 @@ 
def test_vocab_tiling_gradient_tied_embedding(self): matmul_precision="high", num_vocab_tiling=1, ) + if cfg_non_tiling.enable_nnx: + self.skipTest("Vocab tiling is not supported with NNX yet.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) @@ -401,6 +407,8 @@ def test_vocab_tiling_gradient_data_parallelism(self): matmul_precision="high", num_vocab_tiling=1, ) + if cfg_non_tiling.enable_nnx: + self.skipTest("Vocab tiling is not supported with NNX yet.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -465,6 +473,8 @@ def test_vocab_tiling_gradient_tensor_parallelism(self): matmul_precision="high", num_vocab_tiling=1, ) + if cfg_non_tiling.enable_nnx: + self.skipTest("Vocab tiling is not supported with NNX yet.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -531,6 +541,8 @@ def test_vocab_tiling_gradient_context_parallelism(self): matmul_precision="high", num_vocab_tiling=1, ) + if cfg_non_tiling.enable_nnx: + self.skipTest("Vocab tiling is not supported with NNX yet.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index cb291e13bd..add20e7c5a 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -196,6 +196,7 @@ def test_sequence_parallelism(self): "global_parameter_scale=32", "per_device_batch_size=0.0625", 
"max_target_length=65536", + "attention=flash", # Long seq requires flash; dot_product from decoupled config OOMs. ) ) @@ -309,6 +310,7 @@ def test_custom_64x4_mesh(self): "max_target_length=65536", "allow_split_physical_axes=true", "custom_mesh=hybrid_ring_64x4", + "attention=flash", # Long seq requires flash; dot_product from decoupled config OOMs. ) ) diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py new file mode 100644 index 0000000000..53318469fa --- /dev/null +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -0,0 +1,291 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TrainStateNNX checkpoint tests.""" + +import pathlib +import tempfile +import shutil + +import unittest +import jax +import jax.numpy as jnp +from flax import nnx, serialization +from flax import linen as nn +from flax.training import train_state +import optax +import orbax.checkpoint as ocp + +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """A simple model for checkpoint testing.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class LinenMockModel(nn.Module): + """The Linen equivalent of the MockModel.""" + + @nn.compact + def __call__(self, x): + # We name the layer 'linear' to match the attribute name in the NNX MockModel + return nn.Dense(features=1, name="linear")(x) + + +class TestTrainStateNNXCheckpoint(unittest.TestCase): + """Class to test NNX checkpoint.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + + # Setup a chained optimizer: Gradient Clipping -> Adam + # Note: optax.adam is also a chain (scale_by_adam + scale_by_learning_rate). + # This creates a nested state structure: (EmptyState, (ScaleByAdamState, EmptyState)) + self.tx = optax.chain( + optax.clip_by_global_norm(max_norm=1.0), + optax.adam(1e-3), + ) + + def test_checkpoint_structure(self): + """Ensures the state object contains both model and optimizer keys.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # We use .to_pure_dict() to simulate the format stored in a checkpoint. + # This converts nnx.Variable/State objects into raw arrays and dictionaries. + full_state = nnx.state(state).to_pure_dict() + + # 1. Verify Top-level Keys + self.assertIn("model", full_state) + self.assertIn("optimizer", full_state) + + # 2. 
Verify Optimizer Internal Structure + opt_inner_state = full_state["optimizer"]["opt_state"] + + # Because we used optax.chain(clip, adam), index 0 is clip, index 1 is adam. + # Since adam is also a chain, index 1 is itself a dictionary/tuple representation. + # Adam's momentum (mu/nu) is in the first element of its own sub-chain. + adam_component = opt_inner_state[1][0] + + self.assertIn("mu", adam_component, "Adam 'mu' buffer not found in pure dict state.") + self.assertIn("nu", adam_component, "Adam 'nu' buffer not found in pure dict state.") + + # In a pure dict, these are nested dictionaries containing arrays, not NNX objects. + self.assertIsInstance(adam_component["mu"], dict) + self.assertIsInstance(adam_component["nu"], dict) + + # To verify a specific leaf, we navigate the dictionary hierarchy: + self.assertIsInstance(adam_component["mu"]["linear"]["kernel"], jax.Array) + + def test_checkpoint_and_restore(self): + """Verifies that the full state can be captured and restored into a new instance.""" + # 1. Initialize original state and optimizer + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state_original = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # 2. Perform a training step to modify weights and optimizer buffers + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state_original.model) + state_original.apply_gradients(grads) + + # Capture state after one step + original_kernel_val = state_original.model.linear.kernel.value + original_step_val = state_original.optimizer.step.value + self.assertEqual(original_step_val, 1) + + # 3. Capture the "Checkpoint" as a pure dictionary + checkpoint_state = nnx.state(state_original).to_pure_dict() + + # 4. 
Initialize a fresh, different instance + new_rngs = nnx.Rngs(1) + new_model = MockModel(rngs=new_rngs) + new_optimizer = nnx.Optimizer(new_model, self.tx, wrt=nnx.Param) + state_restored = train_state_nnx.TrainStateNNX(new_model, new_optimizer) + + # Check differences before restoration + self.assertEqual(state_restored.optimizer.step.value, 0) + self.assertFalse(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # 5. Restore the state into the new instance. + # nnx.update supports updating from a pure dictionary. + nnx.update(state_restored, checkpoint_state) + + # 6. Verify restoration + # Check step counter + self.assertEqual(state_restored.optimizer.step.value, original_step_val) + # Check model weights + self.assertTrue(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # Check that it can still be trained after restoration + new_grads = nnx.grad(loss_fn)(state_restored.model) + state_restored.apply_gradients(new_grads) + self.assertEqual(state_restored.optimizer.step.value, 2) + + def test_restore_from_linen_state(self): + """Verifies a multi-stage migration: Linen CKPT -> Migrate -> NNX CKPT -> Restore.""" + # 1. 
Setup Linen TrainState (Simulating original training) + linen_model = LinenMockModel() + dummy_input = jnp.ones((1, 2)) + variables = linen_model.init(jax.random.key(42), dummy_input) + + state_linen = train_state.TrainState.create(apply_fn=linen_model.apply, params=variables["params"], tx=self.tx) + + # Perform a step to populate optimizer buffers + grads = jax.tree.map(jnp.ones_like, state_linen.params) + state_linen = state_linen.apply_gradients(grads=grads) + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save Legacy Linen Checkpoint --- + linen_ckpt_dir = temp_dir / "linen_ckpt" + mngr_linen = ocp.CheckpointManager( + linen_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr_linen.save(0, args=ocp.args.StandardSave(state_linen)) + mngr_linen.wait_until_finished() + + # --- PHASE 2: Read Linen CKPT and Convert to NNX Structure --- + # Load it back without knowing the blueprint (reading as a pure PyTree) + restored_linen_obj = mngr_linen.restore(0) + + # Convert the restored object to a pure dictionary structure. + restored_linen_dict = serialization.to_state_dict(restored_linen_obj) + + # Helper to recursively convert string keys back to integers + # and filter out None values. + def recursive_clean(obj): + if isinstance(obj, dict): + return {int(k) if k.isdigit() else k: recursive_clean(v) for k, v in obj.items() if v is not None} + return obj + + # Converted dict - simple PyTree mapping, no NNX Module initialization needed here. + # This simulates a situation where the conversion logic is blueprint-agnostic. 
+ linen_as_nnx_dict = { + "model": restored_linen_dict["params"], + "optimizer": { + "step": jnp.array(restored_linen_dict["step"]), + "opt_state": recursive_clean(restored_linen_dict["opt_state"]), + }, + } + + # --- PHASE 3: Save as Native NNX Checkpoint --- + nnx_ckpt_dir = temp_dir / "nnx_ckpt" + mngr_nnx = ocp.CheckpointManager( + nnx_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + # We save the raw dictionary directly to disk. + mngr_nnx.save(0, args=ocp.args.StandardSave(linen_as_nnx_dict)) + mngr_nnx.wait_until_finished() + + # --- PHASE 4: Restore from NNX Checkpoint to target Model --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We now restore using the nnx.State as a blueprint. This ensures Orbax + # correctly maps the arrays on disk to the model's structural expectation. + blueprint = nnx.state(state_nnx).to_pure_dict() + restored_nnx_pytree = mngr_nnx.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + nnx.update(state_nnx, restored_nnx_pytree) + + # --- PHASE 5: Verification --- + # 1. Verify Step + self.assertEqual(state_nnx.optimizer.step.value, 1) + + # 2. Verify Weights + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, state_linen.params["linear"]["kernel"])) + + # 3. Verify Chained Optimizer State (Clip at index 0, Adam at index 1) + self.assertEqual(type(state_nnx.optimizer.opt_state[0]), type(state_linen.opt_state[0])) + + # state_linen.opt_state[1] is the Adam chain state. + # state_linen.opt_state[1][0] is the ScaleByAdamState containing 'mu'. 
+ self.assertTrue( + jnp.allclose( + state_nnx.optimizer.opt_state[1][0].mu["linear"]["kernel"], + state_linen.opt_state[1][0].mu["linear"]["kernel"], + ) + ) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + def test_restore_from_checkpoint_model_params(self): + """Verifies that model parameters can be restored from model params only.""" + # 1. Setup mocked parameters manually (no Linen model needed for setup) + # This structure matches the path model.linear.kernel/bias in the NNX MockModel. + mock_params = {"linear": {"kernel": jnp.ones((2, 1)) * 9.0, "bias": jnp.zeros((1,))}} + + # Simplified checkpoint dictionary using hardcoded mocked params as requested + checkpoint_dict = { + "model": mock_params, + } + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save the partial checkpoint --- + mngr = ocp.CheckpointManager( + temp_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr.save(0, args=ocp.args.StandardSave(checkpoint_dict)) + mngr.wait_until_finished() + + # --- PHASE 2: Restore into a full TrainStateNNX --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We use nnx.state to get a full blueprint as a reference. + full_nnx_pure_dict = nnx.state(state_nnx).to_pure_dict() + blueprint = {"model": full_nnx_pure_dict["model"]} + + # If we don't know if the checkpoint on disk has 'optimizer' or not, we simulate + # schema-agnostic restoration by calling restore without a blueprint. + # This avoids Orbax structural mismatch errors while allowing us to see the data. + restored_pytree = mngr.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + + # Use nnx.update to apply the restored data to the stateful NNX object. 
+ # nnx.update is naturally partial: it will update 'model' from the restored dict + # and leave 'optimizer' untouched at its initialized value. + nnx.update(state_nnx, restored_pytree) + + # --- PHASE 3: Verification --- + # Check that weights were restored to the specific mock values + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, mock_params["linear"]["kernel"])) + # Step remains at its initialized value (0) because it was not in the checkpoint + self.assertEqual(state_nnx.optimizer.step.value, 0) + + # Verify that the optimizer state still exists in the object (initialized) + # even though it was not provided in the checkpoint. + # Adam's state is at index 1 of the chain, and it's a nested structure (tuple). + # We verify that index 0 (ScaleByAdamState) contains the 'mu' State container. + self.assertIsInstance(state_nnx.optimizer.opt_state[1][0].mu, nnx.State) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_test.py b/tests/unit/train_state_nnx_test.py new file mode 100644 index 0000000000..03db77ff63 --- /dev/null +++ b/tests/unit/train_state_nnx_test.py @@ -0,0 +1,90 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TrainStateNNX tests.""" + +import unittest +import jax.numpy as jnp +from flax import nnx +import optax + +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """Mocked NNX model""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class TestTrainStateNNX(unittest.TestCase): + """TrainStateNNX tests.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + self.tx = optax.adam(1e-3) + + def test_init_with_optimizer(self): + """Test init with iptimizer.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + self.assertEqual(state.model, self.model) + self.assertEqual(state.optimizer, optimizer) + # Access step directly from optimizer + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_optimizer(self): + """Test init without optimizer.""" + state = train_state_nnx.TrainStateNNX(self.model, None) + + self.assertEqual(state.model, self.model) + self.assertIsNone(state.optimizer) + + def test_apply_gradients_success(self): + """Test apply gradients can be called successfully.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # Create dummy gradients matching the model state structure + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state.model) + + # Apply gradients + state.apply_gradients(grads) + + # Verify step incremented (managed by nnx.Optimizer) + self.assertEqual(state.optimizer.step.value, 1) + + def test_apply_gradients_raises_runtime_error(self): + """Test apply gradients without a optimizer.""" + # Initialize without optimizer (inference mode) + state = train_state_nnx.TrainStateNNX(self.model, None) + + dummy_grads = {} + with self.assertRaises(RuntimeError) as cm: + 
state.apply_gradients(dummy_grads) + + self.assertIn("inference only", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_utils_test.py b/tests/unit/train_utils_test.py new file mode 100644 index 0000000000..a8b9458794 --- /dev/null +++ b/tests/unit/train_utils_test.py @@ -0,0 +1,196 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for train_utils.py.""" + +import unittest +from dataclasses import dataclass +from unittest.mock import MagicMock + +from maxtext.utils.train_utils import validate_train_config, create_training_optimizer + + +@dataclass +class MockConfig: + """Minimal mock config for validate_train_config tests.""" + + run_name: str = "test_run" + dataset_path: str = "gs://test-bucket/data" + base_output_directory: str = "gs://test-bucket/output" + steps: int = 100 + quantization: str = "" + gradient_accumulation_steps: int = 1 + packing: bool = False + dataset_type: str = "tfds" + + # Fields needed for create_training_optimizer + opt_type: str = "adamw" + adam_b1: float = 0.9 + adam_b2: float = 0.95 + adam_eps: float = 1e-8 + adam_eps_root: float = 0.0 + adam_weight_decay: float = 0.1 + mu_dtype: str = "" + learning_rate: float = 1e-4 + learning_rate_schedule_steps: int = 1000 + warmup_steps_fraction: float = 0.1 + cosine_learning_rate_final_fraction: float = 0.0 + steps: int = 100 + lr_schedule_type: str = "cosine" + use_iota_embed: bool = False + + +class 
TestValidateTrainConfig(unittest.TestCase): + """Tests for validate_train_config.""" + + def test_valid_config_passes(self): + """Verifies no exception raised for a valid config.""" + config = MockConfig() + # Should not raise + validate_train_config(config) + + def test_missing_run_name_raises(self): + """Verifies AssertionError when run_name is empty.""" + config = MockConfig(run_name="") + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_zero_steps_raises(self): + """Verifies AssertionError when steps is 0.""" + config = MockConfig(steps=0) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_negative_steps_raises(self): + """Verifies AssertionError when steps is negative.""" + config = MockConfig(steps=-5) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_fp8_with_grad_accumulation_raises(self): + """Verifies AssertionError for fp8 quantization + gradient_accumulation_steps > 1.""" + config = MockConfig(quantization="fp8", gradient_accumulation_steps=2) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_nanoo_fp8_with_grad_accumulation_raises(self): + """Verifies AssertionError for nanoo_fp8 quantization + gradient_accumulation_steps > 1.""" + config = MockConfig(quantization="nanoo_fp8", gradient_accumulation_steps=4) + with self.assertRaises(AssertionError): + validate_train_config(config) + + def test_fp8_with_single_grad_accumulation_passes(self): + """Verifies no error for fp8 with gradient_accumulation_steps=1.""" + config = MockConfig(quantization="fp8", gradient_accumulation_steps=1) + validate_train_config(config) # Should not raise + + def test_packing_with_synthetic_data_logs_warning(self): + """Verifies no exception for packing + synthetic (just logs a warning).""" + config = MockConfig(packing=True, dataset_type="synthetic") + # Should not raise - just log a warning + validate_train_config(config) + + def 
test_local_dataset_path_logs_warning(self): + """Verifies no exception for local dataset_path (just logs a warning).""" + config = MockConfig(dataset_path="/local/path/to/data") + validate_train_config(config) # Should not raise + + def test_local_output_directory_logs_warning(self): + """Verifies no exception for local base_output_directory (just logs a warning).""" + config = MockConfig(base_output_directory="/local/output") + validate_train_config(config) # Should not raise + + +class TestCreateTrainingOptimizer(unittest.TestCase): + """Tests for create_training_optimizer.""" + + def _make_config(self, opt_type="adamw", **kwargs): + """Creates a mock config for optimizer tests.""" + cfg = MockConfig(opt_type=opt_type, **kwargs) + return cfg + + def _mock_lr_schedule(self): + """Returns a mock learning rate schedule that returns a fixed value.""" + return lambda step: 1e-4 + + def test_adamw_optimizer_returns_schedule_and_tx(self): + """Verifies create_training_optimizer returns a schedule and optax transform for adamw.""" + config = MagicMock() + config.opt_type = "adamw" + config.adam_b1 = 0.9 + config.adam_b2 = 0.999 + config.adam_eps = 1e-8 + config.adam_eps_root = 0.0 + config.adam_weight_decay = 0.01 + config.mu_dtype = None + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.1 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = "cosine" + config.use_iota_embed = False + + schedule, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(schedule) + self.assertIsNotNone(tx) + # Verify it's an optax GradientTransformation + self.assertTrue(hasattr(tx, "init")) + self.assertTrue(hasattr(tx, "update")) + + def test_adam_pax_optimizer_returns_tx(self): + """Verifies create_training_optimizer works for adam_pax optimizer.""" + config = MagicMock() + config.opt_type = "adam_pax" + config.adam_b1 = 0.9 + config.adam_b2 = 0.999 + 
config.adam_eps = 1e-8 + config.adam_eps_root = 0.0 + config.adam_weight_decay = 0.01 + config.mu_dtype = None + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.1 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = "cosine" + config.use_iota_embed = False + + _, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(tx) + self.assertTrue(hasattr(tx, "init")) + self.assertTrue(hasattr(tx, "update")) + + def test_sgd_optimizer_returns_tx(self): + """Verifies create_training_optimizer works for sgd optimizer.""" + config = MagicMock() + config.opt_type = "sgd" + config.learning_rate = 1e-4 + config.warmup_steps_fraction = 0.0 + config.cosine_learning_rate_final_fraction = 0.0 + config.steps = 100 + config.learning_rate_schedule_steps = 100 + config.lr_schedule_type = "cosine" + config.use_iota_embed = False + + _, tx = create_training_optimizer(config, model=None) + + self.assertIsNotNone(tx) + self.assertTrue(hasattr(tx, "init")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index 4d6f9982b6..be9fb05049 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -37,6 +37,7 @@ """Check if the logits generated by a model's src/maxtext/HF implementation matches golden logits for the same inputs""" import argparse +import functools import os from pathlib import Path import sys @@ -251,8 +252,13 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - state, _ = maxtext_utils.setup_decode_state(model, config, rng1, mesh, None) + if 
config.pure_nnx: + # NNX has a different function to init the training state. + raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, config, False, rng1) + state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) if test_args.golden_logits_path == "": input_golden_data_path = os.path.join( @@ -435,8 +441,13 @@ def main(config, test_args): # pylint: disable=W0621 devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - maxtext_state, _ = maxtext_utils.setup_decode_state(maxtext_model, config, rng1, mesh, None) + if config.pure_nnx: + # NNX has a different function to init the training state. 
+ raise NotImplementedError("Pure NNX support has not been implemented yet.") + else: + maxtext_model = models.transformer_as_linen(config, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, maxtext_model, None, config, False, rng1) + maxtext_state, _ = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn) prompts = ["I love to", "Today is a", "What is the"] all_data_to_save = [] diff --git a/tests/utils/run_sharding_dump.py b/tests/utils/run_sharding_dump.py index e1ad7fbba6..6d01a7c5f5 100644 --- a/tests/utils/run_sharding_dump.py +++ b/tests/utils/run_sharding_dump.py @@ -58,25 +58,26 @@ flags.DEFINE_string("model_name", None, "Specific model name to dump.") flags.DEFINE_string("topology", None, "Specific topology to dump.") flags.DEFINE_string("num_slice", None, "Specific number of slices to dump.") +flags.DEFINE_bool("pure_nnx", True, "Use pure NNX model.") -def run_single_dump(model_name: str, topology: str, num_slice: str) -> None: +def run_single_dump(model_name: str, topology: str, num_slice: str, pure_nnx: bool = True) -> None: """Generate sharding json file for one specific model, topology and slice.""" - subprocess.run( - [ - "python3", - "-m", - "tests.utils.sharding_dump", - get_test_config_path(), - f"compile_topology={topology}", - f"compile_topology_num_slices={num_slice}", - f"model_name={model_name}", - "weight_dtype=float32", - "log_config=false", - "debug_sharding=true", - ], - check=True, - ) + cmd = [ + "python3", + "-m", + "tests.utils.sharding_dump", + get_test_config_path(), + f"compile_topology={topology}", + f"compile_topology_num_slices={num_slice}", + f"model_name={model_name}", + "weight_dtype=float32", + "log_config=false", + "debug_sharding=true", + ] + if pure_nnx: + cmd.append("pure_nnx=true") + subprocess.run(cmd, check=True) def main(argv: Sequence[str]) -> None: @@ -106,7 +107,7 @@ def main(argv: Sequence[str]) -> None: print(" -> Sharding files 
already exist. Regenerating to overwrite.") try: - run_single_dump(model_name, topology, str(num_slice)) + run_single_dump(model_name, topology, str(num_slice), pure_nnx=FLAGS.pure_nnx) except subprocess.CalledProcessError: print(f"!!! FAILED: {model_name} {topology} {num_slice}") diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json index cbee49b201..0f4beb7f8d 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/input_shardings.json @@ -84,6 +84,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[192,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attention_mla/out: bfloat16[192,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json index 5e33fc22b8..7cbe66953c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/logical_shardings.json @@ -1,21 +1,102 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + 
"['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -23,11 +104,21 @@ 10944 ] }, - 
".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -35,11 +126,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -47,42 +148,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + 
"['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -91,12 +247,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -105,11 +271,16 @@ 192 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -117,12 +288,22 @@ 576 ] }, - 
".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -131,20 +312,59 @@ 256 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -153,12 +373,57 @@ 64 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -167,12 +432,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -181,12 +455,21 @@ 1408 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -195,11 +478,57 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], 
"shape": [ 2048, @@ -207,11 +536,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -219,11 +558,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -231,42 +580,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -275,12 +679,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -289,11 +703,16 @@ 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", 
- "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -301,12 +720,22 @@ 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -315,33 +744,80 @@ 256 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -349,11 +825,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -361,11 +847,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -373,42 +869,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - 
"dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -417,12 +932,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -431,11 +956,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -443,12 +973,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - 
"partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -457,20 +997,35 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -479,12 +1034,21 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -493,12 +1057,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { 
"partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -507,12 +1080,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -521,11 +1103,21 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -533,11 +1125,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -545,11 +1147,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -557,42 +1169,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -601,12 +1232,22 @@ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -615,11 +1256,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -627,12 +1273,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -641,29 +1297,52 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -671,11 +1350,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -683,11 +1372,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -695,42 +1394,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -739,12 +1457,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -753,11 +1481,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - 
"dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -765,12 +1498,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -779,20 +1522,35 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -801,12 +1559,21 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -815,12 
+1582,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -829,12 +1605,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -843,11 +1628,21 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -855,11 +1650,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + 
"tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -867,11 +1672,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -879,42 +1694,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + 
"tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -923,12 +1757,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -937,11 +1781,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -949,12 +1798,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -963,17 +1822,31 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 
] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json index 1fd6ceb6fd..fa3d861024 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_1/named_shardings.json @@ -1,5 +1,5 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -32,10 +32,17 @@ "autoregressive": 1 } }, - "partition_spec": [], - "shape": [] + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2048 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +75,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -111,29 +113,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -166,29 +151,12 @@ 
"autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -221,29 +189,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 10944, - 1, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -276,19 +227,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -321,19 +265,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -366,19 +303,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 
512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -411,31 +341,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 16, - 1, - 128, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +379,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 16, - 192 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,25 +417,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 576 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -576,31 +455,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - 
"tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 512, - 1, - 16, - 256 + 1 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -633,27 +493,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 102400 + 1 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -689,22 +534,25 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], null, - null + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, - 26, - 64 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -738,14 +586,13 @@ } }, "partition_spec": [ - "expert", - null, [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], + null, [ "fsdp_transpose", "tensor", @@ -754,13 +601,12 @@ ] ], "shape": [ - 64, - 26, 2048, - 1408 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -794,29 +640,1935 @@ } }, "partition_spec": [ - "expert", - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ "fsdp_transpose", 
"tensor", "tensor_sequence", "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" ] ], "shape": [ - 64, + 10944, + 1, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, 
+ "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + 
"tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 512, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + 
"context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + 
"['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + 
"data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, + 102400 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + 
"expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, 
+ "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + 
"stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + 
"tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ 2048, - 1408 + 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -850,29 +2602,108 @@ } }, "partition_spec": [ - "expert", - null, [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", + "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", + "fsdp_transpose", "sequence", + "context", + "context_autoregressive", + "tensor", "tensor_transpose", - "context" - ] + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ - 64, - 26, - 1408, - 2048 + 512, + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + 
"['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -906,28 +2737,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" ] ], "shape": [ - 2048, + 16, 26, - 2816 + 128, + 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -964,25 +2796,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], null, [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 2048, 26, - 2816 + 16, + 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1016,28 +2849,22 @@ } }, "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ] + ], + null, + null ], "shape": [ - 2816, + 2048, 26, - 2048 + 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1071,18 +2898,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], null ], "shape": [ - 2048, - 26 + 512, + 26, + 16, + 256 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1115,19 +2953,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 2048, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1160,19 +2989,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 512, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1205,31 +3025,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 26, - 128, - 2048 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1262,31 +3061,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2048, - 26, - 16, - 192 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -1319,25 +3097,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], - "shape": [ - 2048, - 26, - 576 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1370,31 +3133,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 512, - 26, - 16, - 256 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1436,7 +3178,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1447,7 +3188,7 @@ 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1483,7 +3224,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1526,7 +3267,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1563,7 +3304,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1581,7 +3321,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1618,7 +3358,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1636,7 +3375,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1680,7 +3419,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1691,7 +3429,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1736,7 +3474,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1781,7 +3519,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1826,7 +3564,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1870,7 +3608,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1883,7 +3620,7 @@ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1919,7 +3656,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1940,7 +3676,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1976,9 +3712,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1991,7 +3725,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2027,7 +3761,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2048,7 +3781,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2084,7 +3817,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2101,7 +3833,7 @@ 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2137,9 +3869,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2152,7 
+3882,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2191,7 +3921,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2208,7 +3937,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2247,7 +3976,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2264,7 +3992,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2309,7 +4037,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -2320,7 +4047,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2357,7 +4084,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2375,7 +4101,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2412,7 +4138,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ 
-2430,7 +4155,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2474,7 +4199,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -2485,7 +4209,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2530,7 +4254,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2575,7 +4299,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2620,7 +4344,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2664,7 +4388,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2677,7 +4400,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2713,7 +4436,6 @@ 
"partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2734,7 +4456,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2770,9 +4492,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2785,7 +4505,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2821,7 +4541,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2842,7 +4561,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2884,7 +4603,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2895,7 +4613,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2938,7 +4656,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2975,7 +4693,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2993,7 +4710,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3030,7 +4747,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3048,7 +4764,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3092,7 +4808,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3103,7 +4818,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3148,7 +4863,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3193,7 +4908,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3238,7 +4953,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3282,7 +4997,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3295,7 +5009,7 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3331,7 +5045,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3352,7 +5065,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3388,9 +5101,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3403,7 +5114,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3439,7 +5150,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3460,7 +5170,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3496,7 +5206,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3513,7 +5222,7 @@ 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3549,9 +5258,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3564,7 
+5271,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3603,7 +5310,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3620,7 +5326,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3659,7 +5365,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3676,7 +5381,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3721,7 +5426,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3732,7 +5436,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3769,7 +5473,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3787,7 +5490,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3824,7 +5527,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ 
-3842,7 +5544,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3886,7 +5588,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3897,7 +5598,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3942,7 +5643,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3987,7 +5688,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4032,7 +5733,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4076,7 +5777,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4089,7 +5789,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4125,7 +5825,6 @@ 
"partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4146,7 +5845,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4182,9 +5881,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4197,7 +5894,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4233,7 +5930,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4254,7 +5950,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4296,7 +5992,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4307,7 +6002,43 @@ 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git 
a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json index a12030dbd9..17c11bb90d 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/input_shardings.json @@ -84,6 +84,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[768,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attention_mla/out: bfloat16[768,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json index 5e33fc22b8..7cbe66953c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/logical_shardings.json @@ -1,21 +1,102 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + 
"['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -23,11 +104,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + 
"tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -35,11 +126,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -47,42 +148,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + 
"partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -91,12 +247,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -105,11 +271,16 @@ 192 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -117,12 +288,22 @@ 576 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + 
"tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -131,20 +312,59 @@ 256 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -153,12 +373,57 @@ 64 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -167,12 +432,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -181,12 +455,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -195,11 +478,57 @@ 2048 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -207,11 +536,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + 
"autoregressive" + ] ], "shape": [ 2048, @@ -219,11 +558,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -231,42 +580,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + 
null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -275,12 +679,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -289,11 +703,16 @@ 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -301,12 +720,22 @@ 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + 
[ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -315,33 +744,80 @@ 256 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -349,11 +825,21 @@ 10944 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -361,11 +847,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -373,42 +869,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -417,12 +932,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -431,11 +956,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -443,12 +973,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -457,20 +997,35 @@ 256 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -479,12 +1034,21 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -493,12 +1057,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -507,12 +1080,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -521,11 +1103,21 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -533,11 +1125,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -545,11 +1147,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -557,42 +1169,61 @@ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -601,12 +1232,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + null ], "shape": [ 2048, @@ -615,11 +1256,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -627,12 +1273,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -641,29 +1297,52 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + 
"fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -671,11 +1350,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -683,11 +1372,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -695,42 +1394,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { 
"partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -739,12 +1457,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -753,11 +1481,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -765,12 +1498,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -779,20 +1522,35 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -801,12 +1559,21 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -815,12 +1582,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -829,12 +1605,21 @@ 1408 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -843,11 +1628,21 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -855,11 +1650,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -867,11 +1672,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + 
"expert" + ] ], "shape": [ 2816, @@ -879,42 +1694,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -923,12 +1757,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -937,11 +1781,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -949,12 +1798,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -963,17 +1822,31 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json index 5b2ab94daf..5a520d4fcb 100644 --- a/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json +++ 
b/tests/utils/sharding_info/deepseek2-16b/tpu7x-16/slice_4/named_shardings.json @@ -1,5 +1,5 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -32,10 +32,17 @@ "autoregressive": 1 } }, - "partition_spec": [], - "shape": [] + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2048 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +75,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -111,29 +113,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -166,29 +151,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-221,29 +189,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 10944, - 1, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -276,19 +227,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -321,19 +265,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -366,19 +303,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -411,31 +341,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + 
"partition_spec": [], "shape": [ - 16, - 1, - 128, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +379,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 16, - 192 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,25 +417,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 576 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -576,31 +455,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 512, - 1, - 16, - 256 + 1 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -633,27 +493,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" 
- ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 102400 + 1 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -689,22 +534,25 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], null, - null + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, - 26, - 64 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -738,14 +586,13 @@ } }, "partition_spec": [ - "expert", - null, [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], + null, [ "fsdp_transpose", "tensor", @@ -754,13 +601,12 @@ ] ], "shape": [ - 64, - 26, 2048, - 1408 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -794,29 +640,1935 @@ } }, "partition_spec": [ - "expert", - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ "fsdp_transpose", "tensor", "tensor_sequence", "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" ] ], "shape": [ - 64, + 10944, + 1, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + 
"autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + 
"shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 512, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": 
{ + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + 
null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + 
"tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + 
"shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, + 102400 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + 
"sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { + "mesh": { + "axis_names": [ + 
"diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + 
"tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ 2048, - 1408 + 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -850,29 +2602,108 @@ } }, "partition_spec": [ - "expert", - null, [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", + "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], 
+ "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", + "fsdp_transpose", "sequence", + "context", + "context_autoregressive", + "tensor", "tensor_transpose", - "context" - ] + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ - 64, - 26, - 1408, - 2048 + 512, + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -906,28 +2737,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" ] ], "shape": [ - 2048, + 16, 26, - 2816 + 128, + 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -964,25 +2796,26 @@ [ "fsdp", "sequence", - 
"tensor_transpose", "context", "expert" ], null, [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 2048, 26, - 2816 + 16, + 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1016,28 +2849,22 @@ } }, "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ] + ], + null, + null ], "shape": [ - 2816, + 2048, 26, - 2048 + 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1071,18 +2898,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], null ], "shape": [ - 2048, - 26 + 512, + 26, + 16, + 256 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1115,19 +2953,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 2048, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1160,19 +2989,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 512, - 26 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1205,31 +3025,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 26, - 128, - 2048 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1262,31 +3061,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2048, - 26, - 16, - 192 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1319,25 +3097,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], - "shape": [ - 2048, - 26, - 576 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1370,31 +3133,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 512, - 26, - 16, - 
256 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1436,7 +3178,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1447,7 +3188,7 @@ 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1483,7 +3224,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1526,7 +3267,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1563,7 +3304,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1581,7 +3321,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1618,7 +3358,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1636,7 +3375,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1680,7 +3419,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1691,7 +3429,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1736,7 +3474,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1781,7 +3519,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1826,7 +3564,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1870,7 +3608,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1883,7 +3620,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1919,7 +3656,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1940,7 +3676,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1976,9 +3712,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1991,7 +3725,7 @@ 576 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2027,7 +3761,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2048,7 +3781,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2084,7 +3817,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2101,7 +3833,7 @@ 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2137,9 +3869,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2152,7 +3882,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2191,7 +3921,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2208,7 +3937,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2247,7 +3976,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2264,7 +3992,7 @@ 1408 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2309,7 +4037,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -2320,7 +4047,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2357,7 +4084,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2375,7 +4101,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2412,7 +4138,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2430,7 +4155,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2474,7 +4199,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -2485,7 +4209,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2530,7 +4254,7 @@ 26 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2575,7 +4299,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2620,7 +4344,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2664,7 +4388,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2677,7 +4400,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2713,7 +4436,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2734,7 +4456,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2770,9 +4492,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2785,7 +4505,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2821,7 +4541,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2842,7 +4561,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2884,7 +4603,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2895,7 +4613,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2938,7 +4656,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2975,7 +4693,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2993,7 +4710,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3030,7 +4747,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3048,7 +4764,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3092,7 +4808,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3103,7 +4818,7 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3148,7 +4863,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3193,7 +4908,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3238,7 +4953,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3282,7 +4997,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3295,7 +5009,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3331,7 +5045,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3352,7 +5065,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3388,9 +5101,7 @@ "partition_spec": [ [ "fsdp", - 
"fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3403,7 +5114,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3439,7 +5150,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3460,7 +5170,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3496,7 +5206,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3513,7 +5222,7 @@ 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3549,9 +5258,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3564,7 +5271,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3603,7 +5310,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3620,7 +5326,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3659,7 +5365,6 @@ [ "fsdp", 
"sequence", - "tensor_transpose", "context" ], [ @@ -3676,7 +5381,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3721,7 +5426,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3732,7 +5436,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3769,7 +5473,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3787,7 +5490,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3824,7 +5527,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3842,7 +5544,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3886,7 +5588,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3897,7 +5598,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -3942,7 +5643,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3987,7 +5688,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4032,7 +5733,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4076,7 +5777,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4089,7 +5789,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4125,7 +5825,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4146,7 +5845,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4182,9 +5881,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4197,7 +5894,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4233,7 +5930,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4254,7 +5950,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4296,7 +5992,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4307,7 +6002,43 @@ 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json index 4172fc960f..ccbc1ce7fc 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/input_shardings.json @@ -84,6 +84,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[96,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attention_mla/out: bfloat16[96,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_length_no_exp', 
'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json index 5e33fc22b8..7cbe66953c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/logical_shardings.json @@ -1,21 +1,102 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + 
"['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -23,11 +104,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -35,11 +126,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -47,42 +148,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - 
"norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -91,12 +247,22 @@ 2048 ] }, - 
".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -105,11 +271,16 @@ 192 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -117,12 +288,22 @@ 576 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -131,20 +312,59 @@ 256 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + 
"partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -153,12 +373,57 @@ 64 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -167,12 +432,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -181,12 +455,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -195,11 +478,57 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -207,11 +536,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -219,11 +558,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -231,42 +580,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -275,12 +679,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -289,11 +703,16 @@ 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -301,12 +720,22 @@ 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -315,33 +744,80 @@ 256 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + 
"['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -349,11 +825,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -361,11 +847,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { 
"partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -373,42 +869,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -417,12 +932,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" 
+ "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -431,11 +956,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -443,12 +973,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -457,20 +997,35 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - 
"moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -479,12 +1034,21 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -493,12 +1057,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -507,12 +1080,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -521,11 +1103,21 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + 
"sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -533,11 +1125,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -545,11 +1147,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -557,42 +1169,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -601,12 +1232,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -615,11 +1256,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -627,12 +1273,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -641,29 +1297,52 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -671,11 +1350,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -683,11 +1372,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -695,42 +1394,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -739,12 +1457,22 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -753,11 +1481,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -765,12 +1498,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -779,20 +1522,35 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -801,12 +1559,21 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -815,12 +1582,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -829,12 +1605,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -843,11 +1628,21 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -855,11 +1650,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -867,11 +1672,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -879,42 +1694,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -923,12 +1757,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -937,11 +1781,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -949,12 
+1798,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -963,17 +1822,31 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json index 72cbbdea66..fccbc77d0a 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_1/named_shardings.json @@ -1,5 +1,5 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -32,10 +32,17 @@ "autoregressive": 1 } }, - "partition_spec": [], - "shape": [] + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2048 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-68,17 +75,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -111,29 +113,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -166,29 +151,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -221,29 +189,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 10944, - 1, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -276,19 +227,12 @@ "autoregressive": 1 } }, - 
"partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -321,19 +265,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -366,19 +303,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -411,31 +341,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 16, - 1, - 128, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +379,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 
2048, - 1, - 16, - 192 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,25 +417,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 576 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -576,31 +455,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 512, - 1, - 16, - 256 + 1 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -633,27 +493,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 102400 + 1 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -689,22 +534,25 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], null, - null + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + 
"autoregressive" + ] ], "shape": [ 2048, - 26, - 64 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -738,14 +586,13 @@ } }, "partition_spec": [ - "expert", - null, [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], + null, [ "fsdp_transpose", "tensor", @@ -754,13 +601,12 @@ ] ], "shape": [ - 64, - 26, 2048, - 1408 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -794,29 +640,1935 @@ } }, "partition_spec": [ - "expert", - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ "fsdp_transpose", "tensor", "tensor_sequence", "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" ] ], "shape": [ - 64, + 10944, + 1, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + 
"sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + 
"context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 512, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 
1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, + 102400 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 
8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + 
"stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + 
"stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + 
"fsdp", + "sequence", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + 
"tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 
1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ 2048, - 1408 + 26, + 2816 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -850,29 +2602,108 @@ } }, "partition_spec": [ - "expert", - null, [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", + "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", + "fsdp_transpose", 
"sequence", + "context", + "context_autoregressive", + "tensor", "tensor_transpose", - "context" - ] + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ - 64, - 26, - 1408, - 2048 + 512, + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -906,28 +2737,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" ] ], "shape": [ - 2048, + 16, 26, - 2816 + 128, + 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -964,25 +2796,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], null, [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 2048, 26, - 2816 + 16, + 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1016,28 +2849,22 @@ } }, "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - 
"autoregressive" - ], - null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ] + ], + null, + null ], "shape": [ - 2816, + 2048, 26, - 2048 + 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1071,18 +2898,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], null ], "shape": [ - 2048, - 26 + 512, + 26, + 16, + 256 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1115,19 +2953,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 2048, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1160,19 +2989,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 512, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1205,31 +3025,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 26, - 128, - 2048 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1262,31 +3061,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2048, - 26, - 16, - 192 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1319,25 +3097,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], - "shape": [ - 2048, - 26, - 576 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1370,31 +3133,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 512, - 26, - 16, - 256 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1436,7 +3178,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1447,7 +3188,7 @@ 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1483,7 +3224,7 @@ "partition_spec": [], "shape": [] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1526,7 +3267,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1563,7 +3304,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1581,7 +3321,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1618,7 +3358,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1636,7 +3375,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1680,7 +3419,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1691,7 +3429,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1736,7 +3474,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1781,7 +3519,7 @@ 1 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1826,7 +3564,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1870,7 +3608,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1883,7 +3620,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1919,7 +3656,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1940,7 +3676,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1976,9 +3712,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1991,7 +3725,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2027,7 +3761,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2048,7 +3781,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2084,7 +3817,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2101,7 +3833,7 @@ 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2137,9 +3869,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2152,7 +3882,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2191,7 +3921,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2208,7 +3937,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2247,7 +3976,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2264,7 +3992,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2309,7 +4037,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -2320,7 +4047,7 @@ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2357,7 +4084,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2375,7 +4101,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2412,7 +4138,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2430,7 +4155,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2474,7 +4199,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -2485,7 +4209,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2530,7 +4254,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2575,7 +4299,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2620,7 +4344,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2664,7 +4388,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2677,7 +4400,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2713,7 +4436,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2734,7 +4456,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2770,9 +4492,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2785,7 +4505,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2821,7 +4541,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2842,7 +4561,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2884,7 +4603,6 @@ ], [ "fsdp", 
- "fsdp_transpose", "sequence", "context", "expert" @@ -2895,7 +4613,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2938,7 +4656,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2975,7 +4693,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2993,7 +4710,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3030,7 +4747,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3048,7 +4764,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3092,7 +4808,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3103,7 +4818,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3148,7 +4863,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3193,7 +4908,7 @@ 1 
] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3238,7 +4953,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3282,7 +4997,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3295,7 +5009,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3331,7 +5045,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3352,7 +5065,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3388,9 +5101,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3403,7 +5114,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3439,7 +5150,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3460,7 +5170,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3496,7 +5206,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3513,7 +5222,7 @@ 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3549,9 +5258,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3564,7 +5271,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3603,7 +5310,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3620,7 +5326,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3659,7 +5365,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3676,7 +5381,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3721,7 +5426,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3732,7 +5436,7 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3769,7 +5473,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3787,7 +5490,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3824,7 +5527,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3842,7 +5544,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3886,7 +5588,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3897,7 +5598,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3942,7 +5643,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3987,7 +5688,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4032,7 +5733,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4076,7 +5777,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4089,7 +5789,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4125,7 +5825,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4146,7 +5845,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4182,9 +5881,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4197,7 +5894,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4233,7 +5930,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4254,7 +5950,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4296,7 +5992,6 @@ ], [ "fsdp", 
- "fsdp_transpose", "sequence", "context", "expert" @@ -4307,7 +6002,43 @@ 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json index 2789aa367e..45fece06ae 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/input_shardings.json @@ -84,6 +84,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[384,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attention_mla/out: bfloat16[384,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json index 5e33fc22b8..7cbe66953c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/logical_shardings.json @@ -1,21 +1,102 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - 
".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + 
"['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -23,11 +104,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -35,11 +126,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -47,42 +148,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - 
".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -91,12 +247,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], 
"shape": [ 2048, @@ -105,11 +271,16 @@ 192 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -117,12 +288,22 @@ 576 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -131,20 +312,59 @@ 256 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -153,12 +373,57 @@ 64 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -167,12 +432,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -181,12 +455,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -195,11 +478,57 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { 
+ "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -207,11 +536,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -219,11 +558,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -231,42 +580,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + 
"partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -275,12 +679,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -289,11 +703,16 @@ 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -301,12 +720,22 @@ 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -315,33 +744,80 @@ 256 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.count": { + 
"['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -349,11 +825,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -361,11 +847,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -373,42 +869,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - 
"dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -417,12 +932,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -431,11 +956,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -443,12 +973,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -457,20 +997,35 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -479,12 +1034,21 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + 
"fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -493,12 +1057,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -507,12 +1080,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -521,11 +1103,21 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -533,11 +1125,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - 
"embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -545,11 +1147,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -557,42 +1169,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -601,12 +1232,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -615,11 +1256,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -627,12 +1273,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -641,29 +1297,52 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -671,11 +1350,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -683,11 +1372,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -695,42 +1394,61 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -739,12 +1457,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + 
"tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -753,11 +1481,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -765,12 +1498,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -779,20 +1522,35 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -801,12 +1559,21 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -815,12 +1582,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -829,12 +1605,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -843,11 +1628,21 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -855,11 +1650,21 @@ 2816 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -867,11 +1672,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -879,42 +1694,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - 
"norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -923,12 +1757,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -937,11 +1781,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -949,12 +1798,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + 
], + null ], "shape": [ 512, @@ -963,17 +1822,31 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json index 65120bac91..28830f6cf0 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v5p-16/slice_4/named_shardings.json @@ -1,5 +1,5 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -32,10 +32,17 @@ "autoregressive": 1 } }, - "partition_spec": [], - "shape": [] + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2048 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +75,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -111,29 +113,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - 
"context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -166,29 +151,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -221,29 +189,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 10944, - 1, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -276,19 +227,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -321,19 +265,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": 
[], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -366,19 +303,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -411,31 +341,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 16, - 1, - 128, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +379,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 16, - 192 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,25 +417,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], + "partition_spec": [], "shape": [ - 2048, - 
1, - 576 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -576,31 +455,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 512, - 1, - 16, - 256 + 1 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -633,27 +493,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 102400 + 1 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -689,22 +534,25 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], null, - null + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, - 26, - 64 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -738,14 +586,13 @@ } }, "partition_spec": [ - "expert", - null, [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], + null, [ "fsdp_transpose", "tensor", @@ -754,13 +601,12 @@ ] ], 
"shape": [ - 64, - 26, 2048, - 1408 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -794,29 +640,1935 @@ } }, "partition_spec": [ - "expert", - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ "fsdp_transpose", "tensor", "tensor_sequence", "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" ] ], "shape": [ - 64, + 10944, + 1, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 
1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + 
"mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + 
"fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 512, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + 
"sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 
1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 
1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, + 102400 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + 
"tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + 
"sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, 
+ "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + 
"autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ 2048, - 1408 + 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + 
"tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 
+ } + }, + "partition_spec": [], + "shape": [ + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -850,29 +2602,108 @@ } }, "partition_spec": [ - "expert", - null, [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", + "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", + "fsdp_transpose", "sequence", + "context", + "context_autoregressive", + "tensor", "tensor_transpose", - "context" - ] + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ - 64, - 26, - 1408, - 2048 + 512, + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + 
"['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -906,28 +2737,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" ] ], "shape": [ - 2048, + 16, 26, - 2816 + 128, + 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -964,25 +2796,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], null, [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 2048, 26, - 2816 + 16, + 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1016,28 +2849,22 @@ } }, "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ] + ], + null, + null ], "shape": [ - 2816, + 2048, 26, - 2048 + 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1071,18 +2898,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], null ], "shape": [ - 2048, - 26 + 512, + 26, + 16, + 256 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1115,19 +2953,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 2048, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1160,19 +2989,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 512, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1205,31 +3025,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 26, - 128, - 2048 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1262,31 +3061,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2048, - 26, - 16, - 192 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -1319,25 +3097,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], - "shape": [ - 2048, - 26, - 576 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1370,31 +3133,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 512, - 26, - 16, - 256 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1436,7 +3178,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1447,7 +3188,7 @@ 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1483,7 +3224,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1526,7 +3267,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1563,7 +3304,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1581,7 +3321,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1618,7 +3358,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1636,7 +3375,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1680,7 +3419,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1691,7 +3429,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1736,7 +3474,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1781,7 +3519,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1826,7 +3564,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1870,7 +3608,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1883,7 +3620,7 @@ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1919,7 +3656,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1940,7 +3676,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1976,9 +3712,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1991,7 +3725,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2027,7 +3761,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2048,7 +3781,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2084,7 +3817,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2101,7 +3833,7 @@ 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2137,9 +3869,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2152,7 
+3882,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2191,7 +3921,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2208,7 +3937,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2247,7 +3976,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2264,7 +3992,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2309,7 +4037,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -2320,7 +4047,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2357,7 +4084,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2375,7 +4101,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2412,7 +4138,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ 
-2430,7 +4155,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2474,7 +4199,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -2485,7 +4209,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2530,7 +4254,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2575,7 +4299,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2620,7 +4344,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2664,7 +4388,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2677,7 +4400,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2713,7 +4436,6 @@ 
"partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2734,7 +4456,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2770,9 +4492,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2785,7 +4505,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2821,7 +4541,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2842,7 +4561,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2884,7 +4603,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2895,7 +4613,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2938,7 +4656,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2975,7 +4693,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2993,7 +4710,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3030,7 +4747,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3048,7 +4764,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3092,7 +4808,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3103,7 +4818,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3148,7 +4863,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3193,7 +4908,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3238,7 +4953,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3282,7 +4997,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3295,7 +5009,7 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3331,7 +5045,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3352,7 +5065,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3388,9 +5101,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3403,7 +5114,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3439,7 +5150,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3460,7 +5170,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3496,7 +5206,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3513,7 +5222,7 @@ 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3549,9 +5258,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3564,7 
+5271,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3603,7 +5310,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3620,7 +5326,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3659,7 +5365,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3676,7 +5381,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3721,7 +5426,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3732,7 +5436,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3769,7 +5473,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3787,7 +5490,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3824,7 +5527,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ 
-3842,7 +5544,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3886,7 +5588,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3897,7 +5598,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3942,7 +5643,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3987,7 +5688,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4032,7 +5733,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4076,7 +5777,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4089,7 +5789,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4125,7 +5825,6 @@ 
"partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4146,7 +5845,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4182,9 +5881,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4197,7 +5894,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4233,7 +5930,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4254,7 +5950,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4296,7 +5992,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4307,7 +6002,43 @@ 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git 
a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json index cbee49b201..0f4beb7f8d 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/input_shardings.json @@ -84,6 +84,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[192,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attention_mla/out: bfloat16[192,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json index 5e33fc22b8..7cbe66953c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/logical_shardings.json @@ -1,21 +1,102 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + 
"partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -23,11 +104,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -35,11 +126,21 @@ 10944 ] }, - 
".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -47,42 +148,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + 
"['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -91,12 +247,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -105,11 +271,16 @@ 192 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -117,12 +288,22 @@ 576 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 
512, @@ -131,20 +312,59 @@ 256 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -153,12 +373,57 @@ 64 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -167,12 +432,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -181,12 +455,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -195,11 +478,57 @@ 2048 ] }, - 
".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -207,11 +536,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + 
"autoregressive" + ] ], "shape": [ 2048, @@ -219,11 +558,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -231,42 +580,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + 
null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -275,12 +679,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -289,11 +703,16 @@ 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -301,12 +720,22 @@ 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + 
[ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -315,33 +744,80 @@ 256 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -349,11 +825,21 @@ 10944 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -361,11 +847,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -373,42 +869,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -417,12 +932,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -431,11 +956,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -443,12 +973,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -457,20 +997,35 @@ 256 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -479,12 +1034,21 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -493,12 +1057,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -507,12 +1080,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -521,11 +1103,21 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -533,11 +1125,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -545,11 +1147,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -557,42 +1169,61 @@ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -601,12 +1232,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + null ], "shape": [ 2048, @@ -615,11 +1256,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -627,12 +1273,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -641,29 +1297,52 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + 
"fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -671,11 +1350,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -683,11 +1372,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -695,42 +1394,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { 
"partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -739,12 +1457,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -753,11 +1481,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -765,12 +1498,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -779,20 +1522,35 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -801,12 +1559,21 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -815,12 +1582,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -829,12 +1605,21 @@ 1408 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -843,11 +1628,21 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -855,11 +1650,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -867,11 +1672,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + 
"expert" + ] ], "shape": [ 2816, @@ -879,42 +1694,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -923,12 +1757,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -937,11 +1781,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -949,12 +1798,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -963,17 +1822,31 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json index 1fd6ceb6fd..fa3d861024 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json +++ 
b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_1/named_shardings.json @@ -1,5 +1,5 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -32,10 +32,17 @@ "autoregressive": 1 } }, - "partition_spec": [], - "shape": [] + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2048 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +75,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -111,29 +113,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -166,29 +151,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-221,29 +189,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 10944, - 1, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -276,19 +227,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -321,19 +265,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -366,19 +303,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -411,31 +341,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + 
"partition_spec": [], "shape": [ - 16, - 1, - 128, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +379,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 16, - 192 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,25 +417,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 576 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -576,31 +455,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 512, - 1, - 16, - 256 + 1 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -633,27 +493,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" 
- ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 102400 + 1 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -689,22 +534,25 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], null, - null + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, - 26, - 64 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -738,14 +586,13 @@ } }, "partition_spec": [ - "expert", - null, [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], + null, [ "fsdp_transpose", "tensor", @@ -754,13 +601,12 @@ ] ], "shape": [ - 64, - 26, 2048, - 1408 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -794,29 +640,1935 @@ } }, "partition_spec": [ - "expert", - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ "fsdp_transpose", "tensor", "tensor_sequence", "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" ] ], "shape": [ - 64, + 10944, + 1, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + 
"autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + 
"shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 512, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": 
{ + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + 
null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + 
"tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + 
"shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, + 102400 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + 
"sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { + "mesh": { + "axis_names": [ + 
"diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + 
"tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ 2048, - 1408 + 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -850,29 +2602,108 @@ } }, "partition_spec": [ - "expert", - null, [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", + "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], 
+ "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", + "fsdp_transpose", "sequence", + "context", + "context_autoregressive", + "tensor", "tensor_transpose", - "context" - ] + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ - 64, - 26, - 1408, - 2048 + 512, + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -906,28 +2737,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" ] ], "shape": [ - 2048, + 16, 26, - 2816 + 128, + 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -964,25 +2796,26 @@ [ "fsdp", "sequence", - 
"tensor_transpose", "context", "expert" ], null, [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 2048, 26, - 2816 + 16, + 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1016,28 +2849,22 @@ } }, "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ] + ], + null, + null ], "shape": [ - 2816, + 2048, 26, - 2048 + 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1071,18 +2898,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], null ], "shape": [ - 2048, - 26 + 512, + 26, + 16, + 256 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1115,19 +2953,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 2048, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1160,19 +2989,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 512, - 26 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1205,31 +3025,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 26, - 128, - 2048 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1262,31 +3061,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2048, - 26, - 16, - 192 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1319,25 +3097,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], - "shape": [ - 2048, - 26, - 576 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1370,31 +3133,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 512, - 26, - 16, - 
256 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1436,7 +3178,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1447,7 +3188,7 @@ 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1483,7 +3224,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1526,7 +3267,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1563,7 +3304,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1581,7 +3321,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1618,7 +3358,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1636,7 +3375,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1680,7 +3419,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1691,7 +3429,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1736,7 +3474,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1781,7 +3519,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1826,7 +3564,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1870,7 +3608,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1883,7 +3620,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1919,7 +3656,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1940,7 +3676,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1976,9 +3712,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1991,7 +3725,7 @@ 576 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2027,7 +3761,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2048,7 +3781,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2084,7 +3817,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2101,7 +3833,7 @@ 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2137,9 +3869,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2152,7 +3882,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2191,7 +3921,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2208,7 +3937,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2247,7 +3976,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2264,7 +3992,7 @@ 1408 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2309,7 +4037,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -2320,7 +4047,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2357,7 +4084,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2375,7 +4101,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2412,7 +4138,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2430,7 +4155,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2474,7 +4199,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -2485,7 +4209,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2530,7 +4254,7 @@ 26 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2575,7 +4299,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2620,7 +4344,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2664,7 +4388,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2677,7 +4400,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2713,7 +4436,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2734,7 +4456,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2770,9 +4492,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2785,7 +4505,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2821,7 +4541,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2842,7 +4561,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2884,7 +4603,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2895,7 +4613,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2938,7 +4656,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2975,7 +4693,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2993,7 +4710,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3030,7 +4747,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3048,7 +4764,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3092,7 +4808,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3103,7 +4818,7 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3148,7 +4863,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3193,7 +4908,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3238,7 +4953,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3282,7 +4997,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3295,7 +5009,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3331,7 +5045,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3352,7 +5065,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3388,9 +5101,7 @@ "partition_spec": [ [ "fsdp", - 
"fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3403,7 +5114,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3439,7 +5150,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3460,7 +5170,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3496,7 +5206,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3513,7 +5222,7 @@ 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3549,9 +5258,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3564,7 +5271,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3603,7 +5310,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3620,7 +5326,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3659,7 +5365,6 @@ [ "fsdp", 
"sequence", - "tensor_transpose", "context" ], [ @@ -3676,7 +5381,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3721,7 +5426,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3732,7 +5436,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3769,7 +5473,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3787,7 +5490,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3824,7 +5527,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3842,7 +5544,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3886,7 +5588,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3897,7 +5598,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -3942,7 +5643,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3987,7 +5688,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4032,7 +5733,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4076,7 +5777,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4089,7 +5789,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4125,7 +5825,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4146,7 +5845,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4182,9 +5881,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4197,7 +5894,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4233,7 +5930,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4254,7 +5950,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4296,7 +5992,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4307,7 +6002,43 @@ 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json index a12030dbd9..17c11bb90d 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/input_shardings.json @@ -84,6 +84,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[768,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attention_mla/out: bfloat16[768,2048,16,128]": { "logic_axes": "('activation_batch', 
'activation_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json index 5e33fc22b8..7cbe66953c 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/logical_shardings.json @@ -1,21 +1,102 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + 
"['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -23,11 +104,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -35,11 +126,21 @@ 10944 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -47,42 +148,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - 
"norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -91,12 +247,22 @@ 2048 ] }, - 
".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -105,11 +271,16 @@ 192 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -117,12 +288,22 @@ 576 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -131,20 +312,59 @@ 256 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + 
"partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -153,12 +373,57 @@ 64 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -167,12 +432,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -181,12 +455,21 @@ 1408 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -195,11 +478,57 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -207,11 +536,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -219,11 +558,21 @@ 2816 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -231,42 +580,97 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + 
"['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -275,12 +679,22 @@ 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -289,11 +703,16 @@ 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -301,12 +720,22 @@ 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -315,33 +744,80 @@ 256 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + 
"['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -349,11 +825,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -361,11 +847,21 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { 
"partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -373,42 +869,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -417,12 +932,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" 
+ "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -431,11 +956,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -443,12 +973,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -457,20 +997,35 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - 
"moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -479,12 +1034,21 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -493,12 +1057,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -507,12 +1080,21 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -521,11 +1103,21 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + 
"sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -533,11 +1125,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -545,11 +1147,21 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -557,42 +1169,61 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -601,12 +1232,22 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -615,11 +1256,16 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -627,12 +1273,22 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -641,29 +1297,52 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -671,11 +1350,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -683,11 +1372,21 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "dense_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 10944, @@ -695,42 +1394,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "dense_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "dense_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -739,12 +1457,22 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "dense_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -753,11 +1481,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "dense_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -765,12 +1498,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "dense_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -779,20 +1522,35 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed_moe", - "moe_layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, null ], "shape": [ @@ -801,12 +1559,21 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -815,12 +1582,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "embed_no_exp_moe", - "mlp" + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 64, @@ -829,12 +1605,21 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "partition_spec": [ - "exp", - "moe_layers", - "mlp", - "embed_no_exp_moe" + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 64, @@ -843,11 +1628,21 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -855,11 +1650,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2048, @@ -867,11 +1672,21 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "moe_layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 2816, @@ -879,42 +1694,61 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 2048, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "moe_layers" + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ 512, 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "moe_layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -923,12 +1757,22 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "moe_layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2048, @@ -937,11 +1781,16 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "partition_spec": [ - "embed", - "moe_layers", - "kv_lora_up_proj" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null ], "shape": [ 2048, @@ -949,12 
+1798,22 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { - "partition_spec": [ - "kv_lora", - "moe_layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 512, @@ -963,17 +1822,31 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 102400, 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json index 5b2ab94daf..5a520d4fcb 100644 --- a/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/deepseek2-16b/v6e-16/slice_4/named_shardings.json @@ -1,5 +1,5 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -32,10 +32,17 @@ "autoregressive": 1 } }, - "partition_spec": [], - "shape": [] + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2048 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-68,17 +75,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -111,29 +113,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -166,29 +151,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 10944 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -221,29 +189,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 10944, - 1, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -276,19 +227,12 @@ "autoregressive": 1 } }, - 
"partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -321,19 +265,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 2048, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -366,19 +303,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], + "partition_spec": [], "shape": [ - 512, 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -411,31 +341,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "partition_spec": [], "shape": [ - 16, - 1, - 128, - 2048 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +379,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 
2048, - 1, - 16, - 192 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,25 +417,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], + "partition_spec": [], "shape": [ - 2048, - 1, - 576 + 1 ] }, - ".params/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -576,31 +455,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 512, - 1, - 16, - 256 + 1 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -633,27 +493,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], + "partition_spec": [], "shape": [ - 2048, - 102400 + 1 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -689,22 +534,25 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], null, - null + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + 
"autoregressive" + ] ], "shape": [ 2048, - 26, - 64 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -738,14 +586,13 @@ } }, "partition_spec": [ - "expert", - null, [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], + null, [ "fsdp_transpose", "tensor", @@ -754,13 +601,12 @@ ] ], "shape": [ - 64, - 26, 2048, - 1408 + 1, + 10944 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['model']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -794,29 +640,1935 @@ } }, "partition_spec": [ - "expert", - null, - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ "fsdp_transpose", "tensor", "tensor_sequence", "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" ] ], "shape": [ - 64, + 10944, + 1, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + 
"sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + 
"context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 512, + 1 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 16, + 1, + 128, + 2048 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2048, + 1, + 16, + 192 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 1, + 576 + ] + }, + "['model']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 512, + 1, + 16, + 256 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, + 102400 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + null + ], + "shape": [ + 2048, + 26, + 64 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + 
"stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + 
"diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { 
+ "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 64, + 26, + 2048, + 1408 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + 
"autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] + ], + "shape": [ + 64, + 26, + 1408, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + 
"fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 2048, 26, + 2816 + ] + }, + "['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ 2048, - 1408 + 26, + 2816 + ] + }, + 
"['model']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 2816, + 26, + 2048 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['model']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -850,29 +2602,108 @@ } }, "partition_spec": [ - "expert", - null, [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", + "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ [ + "tensor", + "tensor_transpose" + ], + null + ], + "shape": [ + 2048, + 26 + ] + }, + "['model']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", + "fsdp_transpose", 
"sequence", + "context", + "context_autoregressive", + "tensor", "tensor_transpose", - "context" - ] + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + null ], "shape": [ - 64, - 26, - 1408, - 2048 + 512, + 26 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -906,28 +2737,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null, + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ], - null, - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" ] ], "shape": [ - 2048, + 16, 26, - 2816 + 128, + 2048 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -964,25 +2796,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], null, [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 2048, 26, - 2816 + 16, + 192 ] }, - ".params/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1016,28 +2849,22 @@ } }, "partition_spec": [ - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - 
"autoregressive" - ], - null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" - ] + ], + null, + null ], "shape": [ - 2816, + 2048, 26, - 2048 + 576 ] }, - ".params/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1071,18 +2898,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + null, [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], null ], "shape": [ - 2048, - 26 + 512, + 26, + 16, + 256 ] }, - ".params/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1115,19 +2953,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 2048, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1160,19 +2989,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - null - ], - "shape": [ - 512, - 26 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1205,31 +3025,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null, - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 26, - 128, - 2048 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1262,31 +3061,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2048, - 26, - 16, - 192 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1319,25 +3097,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - null, - null - ], - "shape": [ - 2048, - 26, - 576 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1370,31 +3133,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - null, - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 512, - 26, - 16, - 256 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1436,7 +3178,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1447,7 +3188,7 @@ 2048 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1483,7 +3224,7 @@ "partition_spec": [], "shape": [] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1526,7 +3267,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1563,7 +3304,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1581,7 +3321,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1618,7 +3358,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1636,7 +3375,7 @@ 10944 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1680,7 +3419,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1691,7 +3429,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1736,7 +3474,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1781,7 +3519,7 @@ 1 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1826,7 +3564,7 @@ 1 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1870,7 +3608,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1883,7 +3620,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1919,7 +3656,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1940,7 +3676,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1976,9 +3712,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1991,7 +3725,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2027,7 +3761,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2048,7 +3781,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2084,7 +3817,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2101,7 +3833,7 @@ 102400 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2137,9 +3869,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2152,7 +3882,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2191,7 +3921,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2208,7 +3937,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2247,7 +3976,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2264,7 +3992,7 @@ 1408 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2309,7 +4037,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -2320,7 +4047,7 @@ 2048 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2357,7 +4084,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2375,7 +4101,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2412,7 +4138,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2430,7 +4155,7 @@ 2816 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2474,7 +4199,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -2485,7 +4209,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2530,7 +4254,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2575,7 +4299,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2620,7 +4344,7 @@ 26 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2664,7 +4388,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2677,7 +4400,7 @@ 2048 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2713,7 +4436,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2734,7 +4456,7 @@ 192 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2770,9 +4492,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2785,7 +4505,7 @@ 576 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2821,7 +4541,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2842,7 +4561,7 @@ 256 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2884,7 +4603,6 @@ ], [ "fsdp", 
- "fsdp_transpose", "sequence", "context", "expert" @@ -2895,7 +4613,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2938,7 +4656,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2975,7 +4693,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2993,7 +4710,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3030,7 +4747,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3048,7 +4764,7 @@ 10944 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3092,7 +4808,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3103,7 +4818,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3148,7 +4863,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3193,7 +4908,7 @@ 1 
] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3238,7 +4953,7 @@ 1 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3282,7 +4997,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3295,7 +5009,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3331,7 +5045,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3352,7 +5065,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3388,9 +5101,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3403,7 +5114,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['dense_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3439,7 +5150,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3460,7 +5170,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3496,7 +5206,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3513,7 +5222,7 @@ 102400 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3549,9 +5258,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3564,7 +5271,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3603,7 +5310,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3620,7 +5326,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3659,7 +5365,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3676,7 +5381,7 @@ 1408 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['MoeBlock_0']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3721,7 +5426,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3732,7 +5436,7 @@ 2048 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3769,7 +5473,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3787,7 +5490,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3824,7 +5527,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3842,7 +5544,7 @@ 2816 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['DeepSeekMoeBlock_0']/['shared_experts']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3886,7 +5588,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -3897,7 +5598,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3942,7 +5643,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3987,7 +5688,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['kv_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4032,7 +5733,7 @@ 26 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4076,7 +5777,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4089,7 +5789,7 @@ 2048 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4125,7 +5825,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4146,7 +5845,7 @@ 192 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_a']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4182,9 +5881,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4197,7 +5894,7 @@ 576 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['moe_layers']/['self_attention']/['wkv_b']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4233,7 +5930,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4254,7 +5950,7 @@ 256 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4296,7 +5992,6 @@ ], [ "fsdp", 
- "fsdp_transpose", "sequence", "context", "expert" @@ -4307,7 +6002,43 @@ 2048 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json index 1f050c09b8..486f0c2dea 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[192,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attentions/out: bfloat16[192,2048,64,64]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json index 35b79ae83c..44bbaec1c8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/logical_shardings.json @@ -1,21 +1,85 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - 
".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -23,12 +87,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -37,22 +111,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -61,11 +150,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { 
"partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -73,12 +167,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -87,21 +191,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -109,12 +218,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -123,20 +242,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -145,12 +269,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -159,11 +292,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -171,12 +308,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -185,11 +331,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + 
"expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -197,12 +347,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -211,11 +370,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -223,31 +385,78 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -255,12 +464,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -269,22 +488,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { 
"partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -293,11 +527,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -305,12 +544,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -319,21 +568,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -341,12 +595,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - 
"kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -355,20 +619,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -377,12 +646,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -391,11 +669,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -403,12 +685,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - 
"mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -417,11 +708,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -429,12 +724,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -443,11 +747,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -455,63 +762,157 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + 
"['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -519,12 +920,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -533,22 +944,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { 
"partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -557,11 +983,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -569,12 +1000,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -583,21 +1024,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -605,12 +1051,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -619,20 +1075,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -641,12 +1102,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + 
"fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -655,11 +1125,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -667,12 +1141,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -681,11 +1164,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -693,12 +1180,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -707,11 +1203,14 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -719,31 +1218,42 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -751,12 +1261,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -765,22 +1285,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -789,11 +1324,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -801,12 +1341,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -815,21 +1365,26 @@ 64 ] }, 
- ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -837,12 +1392,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -851,20 +1416,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -873,12 +1443,21 @@ 32 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -887,11 +1466,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -899,12 +1482,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -913,11 +1505,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -925,12 +1521,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { 
"partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -939,11 +1544,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -951,59 +1559,93 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1011,12 +1653,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1025,22 +1677,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1049,11 +1716,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1061,12 +1733,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1075,21 +1757,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1097,12 +1784,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1111,20 +1808,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1133,12 +1835,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1147,11 +1858,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1159,12 +1874,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1173,11 +1897,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1185,12 +1913,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1199,11 +1936,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", 
+ "tensor_transpose" + ] ], "shape": [ 32, @@ -1211,31 +1951,42 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1243,12 +1994,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1257,22 +2018,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1281,11 +2057,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1293,12 +2074,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1307,21 +2098,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ 
null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1329,12 +2125,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1343,20 +2149,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1365,12 +2176,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - 
"embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1379,11 +2199,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1391,12 +2215,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1405,11 +2238,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1417,12 +2254,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1431,11 +2277,14 
@@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1443,47 +2292,77 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + 
"partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/named_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/named_shardings.json index 78e42a8848..fe71b32d5e 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_1/named_shardings.json @@ -1,5 +1,228 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2880 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + 
"context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + 
"shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -35,7 +258,944 @@ "partition_spec": [], "shape": [] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + 
"partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + 
"tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" + ], + "shape": [ + 2880, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 64, + 12, + 64, + 2880 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 64, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 64, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + 
"fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 64, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 32, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + null + ], + "shape": [ + 2880, + 12, + 32 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + 
"fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] + ], + "shape": [ + 32, + 12, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -69,16 +1229,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence" ] ], "shape": [ + 32, + 12, 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -112,22 +1277,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -161,30 +1332,20 @@ } }, "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], + "expert", "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null + "tensor_transpose" + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -217,23 +1378,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage" - ], + "partition_spec": [], "shape": [ - 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -266,31 +1416,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + 
"diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -323,23 +1492,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null - ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -372,31 +1568,57 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + 
"partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -430,15 +1652,18 @@ } }, "partition_spec": [ - null, + [ + "tensor", + "tensor_transpose" + ], "stage" ], "shape": [ - 64, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -487,7 +1712,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -523,7 +1748,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -544,7 +1768,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -578,15 +1802,20 @@ } }, "partition_spec": [ - null, + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage" ], "shape": [ - 32, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -620,24 +1849,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" - ], - "stage", - null + ] ], "shape": [ - 2880, + 64, 12, - 32 + 64, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -671,29 +1905,22 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + "stage", + null ], "shape": [ - 32, + 64, 12, - 2880, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -727,21 +1954,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ - 32, + 2880, 12, - 2880 + 64, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-775,29 +2010,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] + null, + "stage" ], "shape": [ - 32, - 12, - 2880, - 2880 + 64, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -831,21 +2052,22 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ - 32, + 8, 12, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -879,29 +2101,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] + null ], "shape": [ - 32, - 12, 2880, - 2880 + 12, + 8, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -935,20 +2157,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] + null, + "stage" ], "shape": [ 32, - 12, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -983,17 +2200,21 @@ }, "partition_spec": [ [ - 
"tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context", + "expert" ], - "stage" + "stage", + null ], "shape": [ 2880, - 12 + 12, + 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1027,18 +2248,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context" ], - "stage" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ + 32, + 12, 2880, - 12 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1072,22 +2303,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_sequence" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1121,30 +2351,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ], - "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" - ], - null + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1178,22 +2406,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "fsdp", - "fsdp_transpose", 
- "sequence", + "tensor", "tensor_transpose", - "context", - "expert" - ], - "stage" + "tensor_sequence" + ] ], "shape": [ - 2880, - 12 + 32, + 12, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1227,30 +2454,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null, [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ] ], "shape": [ - 64, + 32, 12, - 64, + 2880, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1284,22 +2509,20 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_transpose" + ] ], "shape": [ - 64, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1332,31 +2555,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1389,16 +2593,12 @@ "autoregressive": 1 } }, - 
"partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1431,23 +2631,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 8, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1480,31 +2669,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2880, - 12, - 8, - 64 + "partition_spec": [], + "shape": [ + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1537,16 +2707,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1579,25 +2745,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 
32 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1631,29 +2784,18 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, 2880, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1687,21 +2829,18 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, - 2880 + 2880, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1735,29 +2874,25 @@ } }, "partition_spec": [ - "expert", - "stage", [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ] ], "shape": [ - 32, - 12, 2880, - 2880 + 201088 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1790,22 +2925,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + 
"['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1838,30 +2961,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] - ], - "shape": [ - 32, - 12, - 2880, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1894,21 +2997,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1941,19 +3033,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1986,19 +3069,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2031,27 +3105,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", 
- "tensor_sequence", - "autoregressive" - ] - ], - "shape": [ - 2880, - 201088 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2093,7 +3150,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2104,7 +3160,7 @@ 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2140,7 +3196,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2183,7 +3239,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2232,7 +3288,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2268,7 +3324,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2289,7 +3344,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2325,9 +3380,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2338,7 +3391,7 @@ 12 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2382,7 +3435,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2395,7 +3447,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2444,7 +3496,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2480,7 +3532,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2501,7 +3552,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2543,7 +3594,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2592,7 +3643,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-2628,7 +3679,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2649,7 +3699,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2691,7 +3741,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2727,9 +3777,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2742,7 +3790,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2781,7 +3829,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2798,7 +3845,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2846,7 +3893,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2885,7 +3932,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2902,7 +3948,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2950,7 +3996,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2995,7 +4041,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3006,7 +4051,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3053,7 +4098,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3098,7 +4143,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3143,7 +4188,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3192,7 +4237,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -3228,7 +4273,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3249,7 +4293,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3285,9 +4329,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3298,7 +4340,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3342,7 +4384,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3355,7 +4396,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3404,7 +4445,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3440,7 +4481,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3461,7 +4501,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3503,7 +4543,7 @@ 
12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3552,7 +4592,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3588,7 +4628,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3609,7 +4648,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3651,7 +4690,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3687,9 +4726,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3702,7 +4739,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3741,7 +4778,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3758,7 +4794,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3806,7 +4842,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3845,7 +4881,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3862,7 +4897,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3910,7 +4945,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3955,7 +4990,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3966,7 +5000,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4013,7 +5047,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4058,7 +5092,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4103,7 +5137,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4139,7 +5173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4156,7 +5189,7 @@ 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4198,7 +5231,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4209,7 +5241,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4252,7 +5284,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4301,7 +5333,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4337,7 +5369,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4358,7 +5389,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4394,9 +5425,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4407,7 +5436,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4451,7 +5480,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4464,7 +5492,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4513,7 +5541,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4549,7 +5577,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4570,7 +5597,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4612,7 +5639,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -4661,7 +5688,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4697,7 +5724,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4718,7 +5744,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4760,7 +5786,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4796,9 +5822,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4811,7 +5835,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4850,7 +5874,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4867,7 +5890,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4915,7 +5938,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4954,7 +5977,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4971,7 +5993,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5019,7 +6041,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5064,7 +6086,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -5075,7 +6096,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5122,7 +6143,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5167,7 +6188,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5212,7 +6233,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5261,7 +6282,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5297,7 +6318,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5318,7 +6338,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5354,9 +6374,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5367,7 +6385,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5411,7 +6429,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5424,7 +6441,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5473,7 +6490,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { 
"axis_names": [ "diloco", @@ -5509,7 +6526,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5530,7 +6546,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5572,7 +6588,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5621,7 +6637,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5657,7 +6673,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5678,7 +6693,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5720,7 +6735,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5756,9 +6771,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5771,7 +6784,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5810,7 +6823,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5827,7 +6839,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5875,7 +6887,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5914,7 +6926,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5931,7 +6942,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5979,7 +6990,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6024,7 +7035,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -6035,7 +7045,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6082,7 +7092,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6127,7 +7137,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6172,7 +7182,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6208,7 +7218,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6225,7 +7234,7 @@ 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6267,7 +7276,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6278,7 +7286,43 @@ 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json 
b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json index 96fab6247a..328efa9e99 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[768,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attentions/out: bfloat16[768,2048,64,64]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json index 35b79ae83c..44bbaec1c8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/logical_shardings.json @@ -1,21 +1,85 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + 
"partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -23,12 +87,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, 
@@ -37,22 +111,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -61,11 +150,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -73,12 +167,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -87,21 +191,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -109,12 +218,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -123,20 +242,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -145,12 +269,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -159,11 +292,15 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -171,12 +308,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -185,11 +331,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -197,12 +347,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -211,11 +370,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -223,31 +385,78 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -255,12 +464,22 @@ 64 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -269,22 +488,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -293,11 +527,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -305,12 +544,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + 
[ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -319,21 +568,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -341,12 +595,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -355,20 +619,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -377,12 +646,21 @@ 32 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -391,11 +669,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -403,12 +685,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -417,11 +708,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -429,12 +724,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + 
] ], "shape": [ 32, @@ -443,11 +747,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -455,63 +762,157 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - 
".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -519,12 +920,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -533,22 +944,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -557,11 +983,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -569,12 +1000,22 
@@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -583,21 +1024,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -605,12 +1051,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -619,20 +1075,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -641,12 +1102,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -655,11 +1125,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -667,12 +1141,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -681,11 +1164,15 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -693,12 +1180,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -707,11 +1203,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -719,31 +1218,42 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], 
"shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -751,12 +1261,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -765,22 +1285,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -789,11 +1324,16 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -801,12 +1341,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -815,21 +1365,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -837,12 +1392,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -851,20 +1416,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -873,12 +1443,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -887,11 +1466,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -899,12 +1482,21 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -913,11 +1505,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -925,12 +1521,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -939,11 +1544,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -951,59 +1559,93 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1011,12 +1653,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1025,22 +1677,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1049,11 +1716,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1061,12 +1733,22 @@ 64 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1075,21 +1757,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1097,12 +1784,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1111,20 +1808,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1133,12 +1835,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1147,11 +1858,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1159,12 +1874,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1173,11 +1897,15 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1185,12 +1913,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1199,11 +1936,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1211,31 +1951,42 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" 
], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1243,12 +1994,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1257,22 +2018,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1281,11 +2057,16 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1293,12 +2074,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1307,21 +2098,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1329,12 +2125,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1343,20 +2149,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1365,12 +2176,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1379,11 +2199,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1391,12 +2215,21 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1405,11 +2238,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1417,12 +2254,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1431,11 +2277,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1443,47 +2292,77 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/named_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/named_shardings.json index ed765f1d18..f8c5b62786 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/tpu7x-16/slice_4/named_shardings.json @@ -1,5 +1,228 @@ { - ".step": { + 
"['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2880 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -35,7 +258,944 @@ "partition_spec": [], "shape": [] }, - 
".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + 
"context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" + ], + "shape": [ + 2880, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 64, + 12, + 64, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 64, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 64, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 64, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": 
{ + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 32, + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + null + ], + "shape": [ + 2880, + 12, + 32 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + 
"context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] + ], + "shape": [ + 32, + 12, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -69,16 +1229,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence" ] ], "shape": [ + 32, + 12, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -112,22 +1277,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880, + 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -161,30 +1332,20 @@ } }, "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], + "expert", "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null + "tensor_transpose" + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -217,23 +1378,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage" - ], + "partition_spec": [], "shape": [ - 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -266,31 +1416,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -323,23 +1492,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null - ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -372,31 +1568,57 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 
1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -430,15 +1652,18 @@ } }, "partition_spec": [ - null, + [ + "tensor", + "tensor_transpose" + ], "stage" ], "shape": [ - 64, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -487,7 +1712,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -523,7 +1748,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -544,7 +1768,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -578,15 +1802,20 @@ } }, "partition_spec": [ - null, + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage" ], "shape": [ - 32, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -620,24 +1849,29 @@ } }, "partition_spec": [ + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" - ], - "stage", - null + ] ], "shape": [ - 2880, + 64, 12, - 32 + 64, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -671,29 +1905,22 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + "stage", + null ], "shape": [ - 32, + 64, 12, - 2880, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -727,21 +1954,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ - 32, + 2880, 12, - 2880 + 64, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -775,29 +2010,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] + null, + "stage" ], "shape": [ - 32, - 12, - 2880, - 2880 + 64, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { 
"axis_names": [ "diloco", @@ -831,21 +2052,22 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ - 32, + 8, 12, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -879,29 +2101,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] + null ], "shape": [ - 32, - 12, 2880, - 2880 + 12, + 8, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -935,20 +2157,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] + null, + "stage" ], "shape": [ 32, - 12, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -983,17 +2200,21 @@ }, "partition_spec": [ [ - "tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context", + "expert" ], - "stage" + "stage", + null ], "shape": [ 2880, - 12 + 12, + 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1027,18 +2248,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "tensor", - "tensor_transpose" + 
"fsdp", + "sequence", + "context" ], - "stage" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ + 32, + 12, 2880, - 12 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1072,22 +2303,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_sequence" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1121,30 +2351,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ], - "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" - ], - null + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1178,22 +2406,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "fsdp", - "fsdp_transpose", - "sequence", + "tensor", "tensor_transpose", - "context", - "expert" - ], - "stage" + "tensor_sequence" + ] ], "shape": [ - 2880, - 12 + 32, + 12, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1227,30 +2454,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - 
"tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null, [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ] ], "shape": [ - 64, + 32, 12, - 64, + 2880, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1284,22 +2509,20 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_transpose" + ] ], "shape": [ - 64, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1332,31 +2555,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1389,16 +2593,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1431,23 +2631,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], 
- "stage", - null - ], + "partition_spec": [], "shape": [ - 8, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1480,31 +2669,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2880, - 12, - 8, - 64 + "partition_spec": [], + "shape": [ + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1537,16 +2707,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1579,25 +2745,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 32 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1631,29 +2784,18 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + 
"stage" ], "shape": [ - 32, - 12, 2880, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1687,21 +2829,18 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, - 2880 + 2880, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1735,29 +2874,25 @@ } }, "partition_spec": [ - "expert", - "stage", [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ] ], "shape": [ - 32, - 12, 2880, - 2880 + 201088 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1790,22 +2925,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1838,30 +2961,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] - ], - "shape": [ - 32, - 12, - 2880, - 2880 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1894,21 +2997,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1941,19 +3033,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1986,19 +3069,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2031,27 +3105,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], - "shape": [ - 2880, - 201088 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2093,7 +3150,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2104,7 +3160,7 @@ 2880 ] }, - ".opt_state/[0]/.count": { + 
"['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2140,7 +3196,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2183,7 +3239,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2232,7 +3288,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2268,7 +3324,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2289,7 +3344,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2325,9 +3380,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2338,7 +3391,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2382,7 +3435,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2395,7 +3447,7 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2444,7 +3496,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2480,7 +3532,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2501,7 +3552,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2543,7 +3594,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2592,7 +3643,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2628,7 +3679,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2649,7 +3699,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2691,7 
+3741,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2727,9 +3777,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2742,7 +3790,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2781,7 +3829,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2798,7 +3845,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2846,7 +3893,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2885,7 +3932,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2902,7 +3948,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2950,7 +3996,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2995,7 +4041,6 @@ [ 
"fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3006,7 +4051,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3053,7 +4098,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3098,7 +4143,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3143,7 +4188,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3192,7 +4237,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3228,7 +4273,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3249,7 +4293,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3285,9 
+4329,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3298,7 +4340,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3342,7 +4384,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3355,7 +4396,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3404,7 +4445,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3440,7 +4481,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3461,7 +4501,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3503,7 +4543,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3552,7 +4592,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3588,7 +4628,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3609,7 +4648,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3651,7 +4690,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3687,9 +4726,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3702,7 +4739,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3741,7 +4778,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3758,7 +4794,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3806,7 +4842,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3845,7 +4881,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3862,7 
+4897,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3910,7 +4945,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3955,7 +4990,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3966,7 +5000,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4013,7 +5047,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4058,7 +5092,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4103,7 +5137,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4139,7 +5173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4156,7 +5189,7 @@ 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4198,7 +5231,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4209,7 +5241,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4252,7 +5284,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4301,7 +5333,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4337,7 +5369,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4358,7 +5389,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4394,9 +5425,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4407,7 +5436,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4451,7 +5480,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4464,7 
+5492,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4513,7 +5541,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4549,7 +5577,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4570,7 +5597,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4612,7 +5639,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4661,7 +5688,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4697,7 +5724,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4718,7 +5744,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": 
[ "diloco", @@ -4760,7 +5786,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4796,9 +5822,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4811,7 +5835,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4850,7 +5874,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4867,7 +5890,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4915,7 +5938,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4954,7 +5977,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4971,7 +5993,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5019,7 +6041,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-5064,7 +6086,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -5075,7 +6096,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5122,7 +6143,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5167,7 +6188,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5212,7 +6233,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5261,7 +6282,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5297,7 +6318,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5318,7 +6338,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -5354,9 +6374,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5367,7 +6385,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5411,7 +6429,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5424,7 +6441,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5473,7 +6490,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5509,7 +6526,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5530,7 +6546,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5572,7 +6588,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5621,7 +6637,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5657,7 +6673,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5678,7 +6693,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5720,7 +6735,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5756,9 +6771,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5771,7 +6784,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5810,7 +6823,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5827,7 +6839,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5875,7 +6887,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5914,7 +6926,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5931,7 
+6942,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5979,7 +6990,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6024,7 +7035,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -6035,7 +7045,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6082,7 +7092,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6127,7 +7137,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6172,7 +7182,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6208,7 +7218,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6225,7 +7234,7 @@ 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6267,7 +7276,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6278,7 +7286,43 @@ 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json index ab45563642..293de91552 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[96,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attentions/out: bfloat16[96,2048,64,64]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json index 35b79ae83c..44bbaec1c8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json +++ 
b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/logical_shardings.json @@ -1,21 +1,85 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -23,12 +87,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -37,22 +111,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -61,11 +150,16 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -73,12 +167,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -87,21 +191,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -109,12 +218,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + null ], "shape": [ 2880, @@ -123,20 +242,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -145,12 +269,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -159,11 +292,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -171,12 +308,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -185,11 +331,15 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -197,12 +347,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -211,11 +370,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -223,31 +385,78 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + 
"shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -255,12 +464,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -269,22 +488,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -293,11 +527,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -305,12 +544,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -319,21 +568,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 
8, @@ -341,12 +595,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -355,20 +619,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -377,12 +646,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -391,11 +669,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -403,12 +685,21 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -417,11 +708,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -429,12 +724,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -443,11 +747,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -455,63 +762,157 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + 
}, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, 
+ "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -519,12 +920,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -533,22 +944,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -557,11 +983,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -569,12 +1000,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -583,21 +1024,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, 
- "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -605,12 +1051,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -619,20 +1075,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -641,12 +1102,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - 
"embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -655,11 +1125,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -667,12 +1141,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -681,11 +1164,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -693,12 +1180,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -707,11 +1203,14 @@ 
2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -719,31 +1218,42 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -751,12 +1261,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + 
"stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -765,22 +1285,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -789,11 +1324,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -801,12 +1341,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -815,21 
+1365,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -837,12 +1392,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -851,20 +1416,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -873,12 +1443,21 @@ 
32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -887,11 +1466,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -899,12 +1482,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -913,11 +1505,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -925,12 +1521,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { 
"partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -939,11 +1544,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -951,59 +1559,93 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1011,12 +1653,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1025,22 +1677,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1049,11 +1716,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1061,12 +1733,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1075,21 +1757,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1097,12 +1784,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1111,20 +1808,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1133,12 +1835,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1147,11 +1858,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1159,12 +1874,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1173,11 +1897,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1185,12 +1913,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1199,11 +1936,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", 
+ "tensor_transpose" + ] ], "shape": [ 32, @@ -1211,31 +1951,42 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1243,12 +1994,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1257,22 +2018,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1281,11 +2057,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1293,12 +2074,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1307,21 +2098,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ 
null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1329,12 +2125,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1343,20 +2149,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1365,12 +2176,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - 
"embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1379,11 +2199,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1391,12 +2215,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1405,11 +2238,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1417,12 +2254,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1431,11 +2277,14 
@@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1443,47 +2292,77 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + 
"partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/named_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/named_shardings.json index 8d8089aac3..c1140e19c2 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_1/named_shardings.json @@ -1,5 +1,228 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2880 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + 
"context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + 
"shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -35,7 +258,944 @@ "partition_spec": [], "shape": [] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + 
"partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + 
"tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" + ], + "shape": [ + 2880, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 64, + 12, + 64, + 2880 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 64, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 64, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 
8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 64, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 32, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + null + ], + "shape": [ + 2880, + 12, + 32 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + 
"fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] + ], + "shape": [ + 32, + 12, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -69,16 +1229,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence" ] ], "shape": [ + 32, + 12, 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -112,22 +1277,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -161,30 +1332,20 @@ } }, "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], + "expert", "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null + "tensor_transpose" + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -217,23 +1378,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage" - ], + "partition_spec": [], "shape": [ - 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -266,31 +1416,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + 
"diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -323,23 +1492,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null - ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -372,31 +1568,57 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + 
"partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -430,15 +1652,18 @@ } }, "partition_spec": [ - null, + [ + "tensor", + "tensor_transpose" + ], "stage" ], "shape": [ - 64, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -487,7 +1712,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -523,7 +1748,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -544,7 +1768,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -578,15 +1802,20 @@ } }, "partition_spec": [ - null, + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage" ], "shape": [ - 32, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -620,24 +1849,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" - ], - "stage", - null + ] ], "shape": [ - 2880, + 64, 12, - 32 + 64, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -671,29 +1905,22 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + "stage", + null ], "shape": [ - 32, + 64, 12, - 2880, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -727,21 +1954,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ - 32, + 2880, 12, - 2880 + 64, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-775,29 +2010,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] + null, + "stage" ], "shape": [ - 32, - 12, - 2880, - 2880 + 64, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -831,21 +2052,22 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ - 32, + 8, 12, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -879,29 +2101,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] + null ], "shape": [ - 32, - 12, 2880, - 2880 + 12, + 8, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -935,20 +2157,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] + null, + "stage" ], "shape": [ 32, - 12, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -983,17 +2200,21 @@ }, "partition_spec": [ [ - 
"tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context", + "expert" ], - "stage" + "stage", + null ], "shape": [ 2880, - 12 + 12, + 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1027,18 +2248,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context" ], - "stage" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ + 32, + 12, 2880, - 12 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1072,22 +2303,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_sequence" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1121,30 +2351,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ], - "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" - ], - null + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1178,22 +2406,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "fsdp", - "fsdp_transpose", 
- "sequence", + "tensor", "tensor_transpose", - "context", - "expert" - ], - "stage" + "tensor_sequence" + ] ], "shape": [ - 2880, - 12 + 32, + 12, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1227,30 +2454,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null, [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ] ], "shape": [ - 64, + 32, 12, - 64, + 2880, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1284,22 +2509,20 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_transpose" + ] ], "shape": [ - 64, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1332,31 +2555,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1389,16 +2593,12 @@ "autoregressive": 1 } }, - 
"partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1431,23 +2631,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 8, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1480,31 +2669,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2880, - 12, - 8, - 64 + "partition_spec": [], + "shape": [ + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1537,16 +2707,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1579,25 +2745,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 
32 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1631,29 +2784,18 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, 2880, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1687,21 +2829,18 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, - 2880 + 2880, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1735,29 +2874,25 @@ } }, "partition_spec": [ - "expert", - "stage", [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ] ], "shape": [ - 32, - 12, 2880, - 2880 + 201088 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1790,22 +2925,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + 
"['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1838,30 +2961,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] - ], - "shape": [ - 32, - 12, - 2880, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1894,21 +2997,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1941,19 +3033,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1986,19 +3069,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2031,27 +3105,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", 
- "tensor_sequence", - "autoregressive" - ] - ], - "shape": [ - 2880, - 201088 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2093,7 +3150,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2104,7 +3160,7 @@ 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2140,7 +3196,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2183,7 +3239,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2232,7 +3288,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2268,7 +3324,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2289,7 +3344,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2325,9 +3380,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2338,7 +3391,7 @@ 12 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2382,7 +3435,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2395,7 +3447,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2444,7 +3496,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2480,7 +3532,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2501,7 +3552,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2543,7 +3594,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2592,7 +3643,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-2628,7 +3679,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2649,7 +3699,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2691,7 +3741,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2727,9 +3777,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2742,7 +3790,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2781,7 +3829,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2798,7 +3845,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2846,7 +3893,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2885,7 +3932,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2902,7 +3948,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2950,7 +3996,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2995,7 +4041,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3006,7 +4051,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3053,7 +4098,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3098,7 +4143,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3143,7 +4188,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3192,7 +4237,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -3228,7 +4273,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3249,7 +4293,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3285,9 +4329,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3298,7 +4340,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3342,7 +4384,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3355,7 +4396,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3404,7 +4445,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3440,7 +4481,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3461,7 +4501,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3503,7 +4543,7 @@ 
12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3552,7 +4592,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3588,7 +4628,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3609,7 +4648,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3651,7 +4690,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3687,9 +4726,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3702,7 +4739,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3741,7 +4778,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3758,7 +4794,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3806,7 +4842,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3845,7 +4881,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3862,7 +4897,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3910,7 +4945,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3955,7 +4990,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3966,7 +5000,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4013,7 +5047,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4058,7 +5092,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4103,7 +5137,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4139,7 +5173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4156,7 +5189,7 @@ 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4198,7 +5231,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4209,7 +5241,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4252,7 +5284,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4301,7 +5333,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4337,7 +5369,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4358,7 +5389,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4394,9 +5425,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4407,7 +5436,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4451,7 +5480,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4464,7 +5492,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4513,7 +5541,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4549,7 +5577,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4570,7 +5597,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4612,7 +5639,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -4661,7 +5688,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4697,7 +5724,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4718,7 +5744,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4760,7 +5786,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4796,9 +5822,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4811,7 +5835,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4850,7 +5874,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4867,7 +5890,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4915,7 +5938,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4954,7 +5977,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4971,7 +5993,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5019,7 +6041,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5064,7 +6086,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -5075,7 +6096,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5122,7 +6143,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5167,7 +6188,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5212,7 +6233,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5261,7 +6282,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5297,7 +6318,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5318,7 +6338,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5354,9 +6374,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5367,7 +6385,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5411,7 +6429,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5424,7 +6441,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5473,7 +6490,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { 
"axis_names": [ "diloco", @@ -5509,7 +6526,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5530,7 +6546,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5572,7 +6588,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5621,7 +6637,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5657,7 +6673,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5678,7 +6693,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5720,7 +6735,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5756,9 +6771,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5771,7 +6784,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5810,7 +6823,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5827,7 +6839,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5875,7 +6887,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5914,7 +6926,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5931,7 +6942,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5979,7 +6990,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6024,7 +7035,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -6035,7 +7045,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6082,7 +7092,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6127,7 +7137,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6172,7 +7182,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6208,7 +7218,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6225,7 +7234,7 @@ 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6267,7 +7276,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6278,7 +7286,43 @@ 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json 
b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json index 0a86cb5c83..493aa85c60 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[384,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attentions/out: bfloat16[384,2048,64,64]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json index 35b79ae83c..44bbaec1c8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/logical_shardings.json @@ -1,21 +1,85 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + 
"shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -23,12 +87,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -37,22 +111,37 @@ 64 ] 
}, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -61,11 +150,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -73,12 +167,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -87,21 +191,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -109,12 +218,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -123,20 +242,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -145,12 +269,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -159,11 +292,15 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -171,12 +308,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -185,11 +331,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -197,12 +347,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -211,11 +370,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -223,31 +385,78 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -255,12 +464,22 @@ 64 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -269,22 +488,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -293,11 +527,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -305,12 +544,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + 
[ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -319,21 +568,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -341,12 +595,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -355,20 +619,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -377,12 +646,21 @@ 32 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -391,11 +669,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -403,12 +685,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -417,11 +708,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -429,12 +724,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + 
] ], "shape": [ 32, @@ -443,11 +747,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -455,63 +762,157 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - 
".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -519,12 +920,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -533,22 +944,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -557,11 +983,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -569,12 +1000,22 
@@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -583,21 +1024,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -605,12 +1051,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -619,20 +1075,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -641,12 +1102,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -655,11 +1125,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -667,12 +1141,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -681,11 +1164,15 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -693,12 +1180,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -707,11 +1203,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -719,31 +1218,42 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], 
"shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -751,12 +1261,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -765,22 +1285,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -789,11 +1324,16 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -801,12 +1341,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -815,21 +1365,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -837,12 +1392,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -851,20 +1416,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -873,12 +1443,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -887,11 +1466,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -899,12 +1482,21 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -913,11 +1505,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -925,12 +1521,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -939,11 +1544,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -951,59 +1559,93 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1011,12 +1653,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1025,22 +1677,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1049,11 +1716,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1061,12 +1733,22 @@ 64 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1075,21 +1757,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1097,12 +1784,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1111,20 +1808,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1133,12 +1835,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1147,11 +1858,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1159,12 +1874,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1173,11 +1897,15 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1185,12 +1913,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1199,11 +1936,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1211,31 +1951,42 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" 
], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1243,12 +1994,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1257,22 +2018,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1281,11 +2057,16 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1293,12 +2074,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1307,21 +2098,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1329,12 +2125,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1343,20 +2149,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1365,12 +2176,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1379,11 +2199,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1391,12 +2215,21 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1405,11 +2238,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1417,12 +2254,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1431,11 +2277,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1443,47 +2292,77 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/named_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/named_shardings.json index a395dba2ea..5685819da5 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v5p-16/slice_4/named_shardings.json @@ -1,5 +1,228 @@ { - ".step": { + 
"['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2880 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -35,7 +258,944 @@ "partition_spec": [], "shape": [] }, - 
".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 
1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + 
"context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" + ], + "shape": [ + 2880, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 64, + 12, + 64, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 64, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 64, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 64, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 
1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 32, + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + null + ], + "shape": [ + 2880, + 12, + 32 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 
1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] + ], + "shape": [ + 32, + 12, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -69,16 +1229,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence" ] ], "shape": [ + 32, + 12, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -112,22 +1277,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880, + 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -161,30 +1332,20 @@ } }, "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], + "expert", "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null + "tensor_transpose" + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -217,23 +1378,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage" - ], + "partition_spec": [], "shape": [ - 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -266,31 +1416,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -323,23 +1492,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null - ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -372,31 +1568,57 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 
1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -430,15 +1652,18 @@ } }, "partition_spec": [ - null, + [ + "tensor", + "tensor_transpose" + ], "stage" ], "shape": [ - 64, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -487,7 +1712,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -523,7 +1748,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -544,7 +1768,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -578,15 +1802,20 @@ } }, "partition_spec": [ - null, + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage" ], "shape": [ - 32, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -620,24 +1849,29 @@ } }, "partition_spec": [ + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" - ], - "stage", - null + ] ], "shape": [ - 2880, + 64, 12, - 32 + 64, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -671,29 +1905,22 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + "stage", + null ], "shape": [ - 32, + 64, 12, - 2880, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -727,21 +1954,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ - 32, + 2880, 12, - 2880 + 64, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -775,29 +2010,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] + null, + "stage" ], "shape": [ - 32, - 12, - 2880, - 2880 + 64, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { 
"axis_names": [ "diloco", @@ -831,21 +2052,22 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ - 32, + 8, 12, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -879,29 +2101,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] + null ], "shape": [ - 32, - 12, 2880, - 2880 + 12, + 8, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -935,20 +2157,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] + null, + "stage" ], "shape": [ 32, - 12, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -983,17 +2200,21 @@ }, "partition_spec": [ [ - "tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context", + "expert" ], - "stage" + "stage", + null ], "shape": [ 2880, - 12 + 12, + 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1027,18 +2248,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "tensor", - "tensor_transpose" + 
"fsdp", + "sequence", + "context" ], - "stage" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ + 32, + 12, 2880, - 12 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1072,22 +2303,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_sequence" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1121,30 +2351,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ], - "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" - ], - null + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1178,22 +2406,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "fsdp", - "fsdp_transpose", - "sequence", + "tensor", "tensor_transpose", - "context", - "expert" - ], - "stage" + "tensor_sequence" + ] ], "shape": [ - 2880, - 12 + 32, + 12, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1227,30 +2454,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - 
"tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null, [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ] ], "shape": [ - 64, + 32, 12, - 64, + 2880, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1284,22 +2509,20 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_transpose" + ] ], "shape": [ - 64, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1332,31 +2555,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1389,16 +2593,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1431,23 +2631,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], 
- "stage", - null - ], + "partition_spec": [], "shape": [ - 8, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1480,31 +2669,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2880, - 12, - 8, - 64 + "partition_spec": [], + "shape": [ + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1537,16 +2707,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1579,25 +2745,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 32 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1631,29 +2784,18 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + 
"stage" ], "shape": [ - 32, - 12, 2880, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1687,21 +2829,18 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, - 2880 + 2880, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1735,29 +2874,25 @@ } }, "partition_spec": [ - "expert", - "stage", [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ] ], "shape": [ - 32, - 12, 2880, - 2880 + 201088 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1790,22 +2925,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1838,30 +2961,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] - ], - "shape": [ - 32, - 12, - 2880, - 2880 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1894,21 +2997,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1941,19 +3033,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1986,19 +3069,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2031,27 +3105,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], - "shape": [ - 2880, - 201088 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2093,7 +3150,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2104,7 +3160,7 @@ 2880 ] }, - ".opt_state/[0]/.count": { + 
"['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2140,7 +3196,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2183,7 +3239,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2232,7 +3288,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2268,7 +3324,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2289,7 +3344,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2325,9 +3380,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2338,7 +3391,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2382,7 +3435,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2395,7 +3447,7 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2444,7 +3496,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2480,7 +3532,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2501,7 +3552,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2543,7 +3594,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2592,7 +3643,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2628,7 +3679,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2649,7 +3699,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2691,7 
+3741,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2727,9 +3777,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2742,7 +3790,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2781,7 +3829,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2798,7 +3845,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2846,7 +3893,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2885,7 +3932,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2902,7 +3948,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2950,7 +3996,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2995,7 +4041,6 @@ [ 
"fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3006,7 +4051,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3053,7 +4098,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3098,7 +4143,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3143,7 +4188,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3192,7 +4237,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3228,7 +4273,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3249,7 +4293,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3285,9 
+4329,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3298,7 +4340,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3342,7 +4384,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3355,7 +4396,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3404,7 +4445,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3440,7 +4481,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3461,7 +4501,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3503,7 +4543,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3552,7 +4592,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3588,7 +4628,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3609,7 +4648,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3651,7 +4690,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3687,9 +4726,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3702,7 +4739,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3741,7 +4778,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3758,7 +4794,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3806,7 +4842,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3845,7 +4881,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3862,7 
+4897,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3910,7 +4945,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3955,7 +4990,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3966,7 +5000,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4013,7 +5047,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4058,7 +5092,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4103,7 +5137,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4139,7 +5173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4156,7 +5189,7 @@ 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4198,7 +5231,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4209,7 +5241,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4252,7 +5284,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4301,7 +5333,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4337,7 +5369,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4358,7 +5389,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4394,9 +5425,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4407,7 +5436,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4451,7 +5480,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4464,7 
+5492,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4513,7 +5541,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4549,7 +5577,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4570,7 +5597,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4612,7 +5639,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4661,7 +5688,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4697,7 +5724,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4718,7 +5744,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": 
[ "diloco", @@ -4760,7 +5786,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4796,9 +5822,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4811,7 +5835,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4850,7 +5874,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4867,7 +5890,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4915,7 +5938,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4954,7 +5977,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4971,7 +5993,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5019,7 +6041,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-5064,7 +6086,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -5075,7 +6096,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5122,7 +6143,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5167,7 +6188,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5212,7 +6233,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5261,7 +6282,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5297,7 +6318,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5318,7 +6338,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -5354,9 +6374,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5367,7 +6385,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5411,7 +6429,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5424,7 +6441,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5473,7 +6490,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5509,7 +6526,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5530,7 +6546,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5572,7 +6588,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5621,7 +6637,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5657,7 +6673,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5678,7 +6693,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5720,7 +6735,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5756,9 +6771,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5771,7 +6784,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5810,7 +6823,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5827,7 +6839,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5875,7 +6887,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5914,7 +6926,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5931,7 
+6942,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5979,7 +6990,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6024,7 +7035,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -6035,7 +7045,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6082,7 +7092,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6127,7 +7137,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6172,7 +7182,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6208,7 +7218,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6225,7 +7234,7 @@ 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6267,7 +7276,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6278,7 +7286,43 @@ 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json index 1f050c09b8..486f0c2dea 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[192,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attentions/out: bfloat16[192,2048,64,64]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json index 35b79ae83c..44bbaec1c8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json +++ 
b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/logical_shardings.json @@ -1,21 +1,85 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -23,12 +87,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -37,22 +111,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -61,11 +150,16 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -73,12 +167,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -87,21 +191,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -109,12 +218,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + null ], "shape": [ 2880, @@ -123,20 +242,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -145,12 +269,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -159,11 +292,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -171,12 +308,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -185,11 +331,15 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -197,12 +347,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -211,11 +370,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -223,31 +385,78 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + 
"shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -255,12 +464,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -269,22 +488,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -293,11 +527,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -305,12 +544,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -319,21 +568,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 
8, @@ -341,12 +595,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -355,20 +619,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -377,12 +646,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -391,11 +669,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -403,12 +685,21 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -417,11 +708,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -429,12 +724,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -443,11 +747,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -455,63 +762,157 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + 
}, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, 
+ "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -519,12 +920,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -533,22 +944,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -557,11 +983,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -569,12 +1000,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -583,21 +1024,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, 
- "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -605,12 +1051,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -619,20 +1075,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -641,12 +1102,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - 
"embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -655,11 +1125,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -667,12 +1141,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -681,11 +1164,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -693,12 +1180,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -707,11 +1203,14 @@ 
2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -719,31 +1218,42 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -751,12 +1261,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + 
"stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -765,22 +1285,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -789,11 +1324,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -801,12 +1341,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -815,21 
+1365,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -837,12 +1392,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -851,20 +1416,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -873,12 +1443,21 @@ 
32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -887,11 +1466,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -899,12 +1482,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -913,11 +1505,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -925,12 +1521,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { 
"partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -939,11 +1544,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -951,59 +1559,93 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1011,12 +1653,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1025,22 +1677,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1049,11 +1716,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1061,12 +1733,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1075,21 +1757,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1097,12 +1784,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1111,20 +1808,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1133,12 +1835,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1147,11 +1858,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1159,12 +1874,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1173,11 +1897,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1185,12 +1913,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1199,11 +1936,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", 
+ "tensor_transpose" + ] ], "shape": [ 32, @@ -1211,31 +1951,42 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1243,12 +1994,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1257,22 +2018,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1281,11 +2057,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1293,12 +2074,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1307,21 +2098,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ 
null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1329,12 +2125,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1343,20 +2149,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1365,12 +2176,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - 
"embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1379,11 +2199,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1391,12 +2215,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1405,11 +2238,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1417,12 +2254,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1431,11 +2277,14 
@@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1443,47 +2292,77 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + 
"partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/named_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/named_shardings.json index 78e42a8848..fe71b32d5e 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_1/named_shardings.json @@ -1,5 +1,228 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2880 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + 
"context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + 
"shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -35,7 +258,944 @@ "partition_spec": [], "shape": [] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + 
"partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + 
"tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" + ], + "shape": [ + 2880, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 64, + 12, + 64, + 2880 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 64, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 64, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + 
"fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 64, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 32, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + null + ], + "shape": [ + 2880, + 12, + 32 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + 
"fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] + ], + "shape": [ + 32, + 12, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -69,16 +1229,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence" ] ], "shape": [ + 32, + 12, 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -112,22 +1277,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -161,30 +1332,20 @@ } }, "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], + "expert", "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null + "tensor_transpose" + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -217,23 +1378,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage" - ], + "partition_spec": [], "shape": [ - 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -266,31 +1416,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + 
"diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -323,23 +1492,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null - ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -372,31 +1568,57 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + 
"partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -430,15 +1652,18 @@ } }, "partition_spec": [ - null, + [ + "tensor", + "tensor_transpose" + ], "stage" ], "shape": [ - 64, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -487,7 +1712,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -523,7 +1748,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -544,7 +1768,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -578,15 +1802,20 @@ } }, "partition_spec": [ - null, + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage" ], "shape": [ - 32, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -620,24 +1849,29 @@ } }, "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" - ], - "stage", - null + ] ], "shape": [ - 2880, + 64, 12, - 32 + 64, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -671,29 +1905,22 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + "stage", + null ], "shape": [ - 32, + 64, 12, - 2880, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -727,21 +1954,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ - 32, + 2880, 12, - 2880 + 64, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-775,29 +2010,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] + null, + "stage" ], "shape": [ - 32, - 12, - 2880, - 2880 + 64, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -831,21 +2052,22 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ - 32, + 8, 12, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -879,29 +2101,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] + null ], "shape": [ - 32, - 12, 2880, - 2880 + 12, + 8, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -935,20 +2157,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] + null, + "stage" ], "shape": [ 32, - 12, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -983,17 +2200,21 @@ }, "partition_spec": [ [ - 
"tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context", + "expert" ], - "stage" + "stage", + null ], "shape": [ 2880, - 12 + 12, + 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1027,18 +2248,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context" ], - "stage" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ + 32, + 12, 2880, - 12 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1072,22 +2303,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_sequence" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1121,30 +2351,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ], - "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" - ], - null + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1178,22 +2406,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "fsdp", - "fsdp_transpose", 
- "sequence", + "tensor", "tensor_transpose", - "context", - "expert" - ], - "stage" + "tensor_sequence" + ] ], "shape": [ - 2880, - 12 + 32, + 12, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1227,30 +2454,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null, [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ] ], "shape": [ - 64, + 32, 12, - 64, + 2880, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1284,22 +2509,20 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_transpose" + ] ], "shape": [ - 64, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1332,31 +2555,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1389,16 +2593,12 @@ "autoregressive": 1 } }, - 
"partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1431,23 +2631,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 8, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1480,31 +2669,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2880, - 12, - 8, - 64 + "partition_spec": [], + "shape": [ + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1537,16 +2707,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1579,25 +2745,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 
32 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1631,29 +2784,18 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, 2880, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1687,21 +2829,18 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, - 2880 + 2880, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1735,29 +2874,25 @@ } }, "partition_spec": [ - "expert", - "stage", [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ] ], "shape": [ - 32, - 12, 2880, - 2880 + 201088 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1790,22 +2925,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + 
"['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1838,30 +2961,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] - ], - "shape": [ - 32, - 12, - 2880, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1894,21 +2997,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1941,19 +3033,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1986,19 +3069,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2031,27 +3105,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", 
- "tensor_sequence", - "autoregressive" - ] - ], - "shape": [ - 2880, - 201088 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2093,7 +3150,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2104,7 +3160,7 @@ 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2140,7 +3196,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2183,7 +3239,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2232,7 +3288,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2268,7 +3324,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2289,7 +3344,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2325,9 +3380,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2338,7 +3391,7 @@ 12 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2382,7 +3435,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2395,7 +3447,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2444,7 +3496,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2480,7 +3532,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2501,7 +3552,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2543,7 +3594,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2592,7 +3643,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-2628,7 +3679,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2649,7 +3699,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2691,7 +3741,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2727,9 +3777,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2742,7 +3790,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2781,7 +3829,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2798,7 +3845,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2846,7 +3893,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2885,7 +3932,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2902,7 +3948,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2950,7 +3996,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2995,7 +4041,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3006,7 +4051,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3053,7 +4098,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3098,7 +4143,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3143,7 +4188,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3192,7 +4237,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { 
"mesh": { "axis_names": [ "diloco", @@ -3228,7 +4273,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3249,7 +4293,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3285,9 +4329,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3298,7 +4340,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3342,7 +4384,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3355,7 +4396,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3404,7 +4445,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3440,7 +4481,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3461,7 +4501,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3503,7 +4543,7 @@ 
12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3552,7 +4592,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3588,7 +4628,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3609,7 +4648,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3651,7 +4690,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3687,9 +4726,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3702,7 +4739,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3741,7 +4778,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3758,7 +4794,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3806,7 +4842,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3845,7 +4881,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3862,7 +4897,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3910,7 +4945,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3955,7 +4990,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3966,7 +5000,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4013,7 +5047,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4058,7 +5092,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4103,7 +5137,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4139,7 +5173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4156,7 +5189,7 @@ 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4198,7 +5231,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4209,7 +5241,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4252,7 +5284,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4301,7 +5333,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4337,7 +5369,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4358,7 +5389,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4394,9 +5425,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4407,7 +5436,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4451,7 +5480,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4464,7 +5492,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4513,7 +5541,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4549,7 +5577,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4570,7 +5597,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4612,7 +5639,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -4661,7 +5688,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4697,7 +5724,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4718,7 +5744,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4760,7 +5786,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4796,9 +5822,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4811,7 +5835,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4850,7 +5874,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4867,7 +5890,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4915,7 +5938,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4954,7 +5977,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4971,7 +5993,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5019,7 +6041,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5064,7 +6086,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -5075,7 +6096,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5122,7 +6143,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5167,7 +6188,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5212,7 +6233,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5261,7 +6282,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5297,7 +6318,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5318,7 +6338,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5354,9 +6374,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5367,7 +6385,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5411,7 +6429,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5424,7 +6441,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5473,7 +6490,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { 
"axis_names": [ "diloco", @@ -5509,7 +6526,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5530,7 +6546,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5572,7 +6588,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5621,7 +6637,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5657,7 +6673,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5678,7 +6693,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5720,7 +6735,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5756,9 +6771,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5771,7 +6784,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5810,7 +6823,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5827,7 +6839,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5875,7 +6887,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5914,7 +6926,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5931,7 +6942,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5979,7 +6990,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6024,7 +7035,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -6035,7 +7045,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6082,7 +7092,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6127,7 +7137,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6172,7 +7182,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6208,7 +7218,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6225,7 +7234,7 @@ 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6267,7 +7276,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6278,7 +7286,43 @@ 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json 
b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json index 96fab6247a..328efa9e99 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[768,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attentions/out: bfloat16[768,2048,64,64]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json index 35b79ae83c..44bbaec1c8 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/logical_shardings.json @@ -1,21 +1,85 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + 
"shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -23,12 +87,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -37,22 +111,37 @@ 64 ] 
}, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -61,11 +150,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -73,12 +167,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -87,21 +191,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -109,12 +218,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -123,20 +242,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -145,12 +269,21 @@ 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -159,11 +292,15 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -171,12 +308,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -185,11 +331,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -197,12 +347,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -211,11 +370,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -223,31 +385,78 @@ 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -255,12 +464,22 @@ 64 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -269,22 +488,37 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -293,11 +527,16 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -305,12 +544,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + 
[ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -319,21 +568,26 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -341,12 +595,22 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -355,20 +619,25 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -377,12 +646,21 @@ 32 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -391,11 +669,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -403,12 +685,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -417,11 +708,15 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -429,12 +724,21 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + 
] ], "shape": [ 32, @@ -443,11 +747,14 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -455,63 +762,157 @@ 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - 
".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -519,12 +920,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -533,22 +944,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -557,11 +983,16 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -569,12 +1000,22 
@@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -583,21 +1024,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -605,12 +1051,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -619,20 +1075,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -641,12 +1102,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -655,11 +1125,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -667,12 +1141,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -681,11 +1164,15 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -693,12 +1180,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -707,11 +1203,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -719,31 +1218,42 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], 
"shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -751,12 +1261,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -765,22 +1285,37 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -789,11 +1324,16 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -801,12 +1341,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -815,21 +1365,26 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -837,12 +1392,22 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -851,20 +1416,25 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -873,12 +1443,21 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -887,11 +1466,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -899,12 +1482,21 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -913,11 +1505,15 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -925,12 +1521,21 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -939,11 +1544,14 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -951,59 +1559,93 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + 
"autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1011,12 +1653,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1025,22 +1677,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1049,11 +1716,16 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1061,12 +1733,22 @@ 64 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1075,21 +1757,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1097,12 +1784,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1111,20 +1808,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1133,12 +1835,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1147,11 +1858,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1159,12 +1874,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1173,11 +1897,15 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1185,12 +1913,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1199,11 +1936,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1211,31 +1951,42 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" 
], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1243,12 +1994,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1257,22 +2018,37 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "partition_spec": [ - "embed", - "layers" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 64, @@ -1281,11 +2057,16 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "partition_spec": [ - "q_heads", - "layers", - "kv" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 64, @@ -1293,12 +2074,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1307,21 +2098,26 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 64, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "partition_spec": [ - "kv_heads", - "layers", - "kv_head_dim" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ 8, @@ -1329,12 +2125,22 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { - "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 2880, @@ -1343,20 +2149,25 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "partition_spec": [ null, - "layers" + "stage" ], "shape": [ 32, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", null ], "shape": [ @@ -1365,12 +2176,21 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1379,11 +2199,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1391,12 +2215,21 @@ 2880 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "partition_spec": [ - "exp", - "layers", - "embed_no_exp_moe", - "mlp" + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 32, @@ -1405,11 +2238,15 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_mlp" + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] ], "shape": [ 32, @@ -1417,12 +2254,21 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "partition_spec": [ - "exp", - "layers", - "mlp", - "embed_no_exp_moe" + "expert", + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ 32, @@ -1431,11 +2277,14 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "partition_spec": [ - "exp", - "layers", - "activation_embed_moe" + "expert", + "stage", + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 32, @@ -1443,47 +2292,77 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "partition_spec": [ - "embed", - "vocab" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 2880, 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 201088, 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/named_shardings.json b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/named_shardings.json index ed765f1d18..f8c5b62786 100644 --- a/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/gpt-oss-20b/v6e-16/slice_4/named_shardings.json @@ -1,5 +1,228 @@ { - ".step": { + 
"['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 2880 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -35,7 +258,944 @@ "partition_spec": [], "shape": [] }, - 
".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + 
"context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage" + ], + "shape": [ + 2880, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 64, + 12, + 64, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 64, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 64, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 64, + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": 
{ + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null + ], + "shape": [ + 8, + 12, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null + ], + "shape": [ + 2880, + 12, + 8, + 64 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + null, + "stage" + ], + "shape": [ + 32, + 12 + ] + }, + 
"['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + null + ], + "shape": [ + 2880, + 12, + 32 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + 
"context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence" + ] + ], + "shape": [ + 32, + 12, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + "expert", + "stage", + [ + "fsdp", + "sequence", + "context" + ], + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 32, + 12, + 2880, + 2880 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -69,16 +1229,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence" ] ], "shape": [ + 32, + 12, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -112,22 +1277,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null + [ + "fsdp", + "sequence", + "context" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880, + 2880 ] }, - 
".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -161,30 +1332,20 @@ } }, "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], + "expert", "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null + "tensor_transpose" + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -217,23 +1378,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage" - ], + "partition_spec": [], "shape": [ - 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -266,31 +1416,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -323,23 +1492,50 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - "stage", - null - ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], "shape": [ - 64, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_0']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -372,31 +1568,57 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [ + 12 + ] + }, + "['model']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 
1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -430,15 +1652,18 @@ } }, "partition_spec": [ - null, + [ + "tensor", + "tensor_transpose" + ], "stage" ], "shape": [ - 64, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -487,7 +1712,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -523,7 +1748,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -544,7 +1768,7 @@ 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -578,15 +1802,20 @@ } }, "partition_spec": [ - null, + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage" ], "shape": [ - 32, + 2880, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -620,24 +1849,29 @@ } }, "partition_spec": [ + [ + "tensor", + 
"tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" - ], - "stage", - null + ] ], "shape": [ - 2880, + 64, 12, - 32 + 64, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -671,29 +1905,22 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + "stage", + null ], "shape": [ - 32, + 64, 12, - 2880, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -727,21 +1954,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ - 32, + 2880, 12, - 2880 + 64, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -775,29 +2010,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ] + null, + "stage" ], "shape": [ - 32, - 12, - 2880, - 2880 + 64, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { 
"axis_names": [ "diloco", @@ -831,21 +2052,22 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", "tensor_transpose", - "tensor_sequence" - ] + "tensor_sequence", + "autoregressive" + ], + "stage", + null ], "shape": [ - 32, + 8, 12, - 2880 + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -879,29 +2101,29 @@ } }, "partition_spec": [ - "expert", + [ + "fsdp", + "sequence", + "context", + "expert" + ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] + null ], "shape": [ - 32, - 12, 2880, - 2880 + 12, + 8, + 64 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -935,20 +2157,15 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] + null, + "stage" ], "shape": [ 32, - 12, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -983,17 +2200,21 @@ }, "partition_spec": [ [ - "tensor", - "tensor_transpose" + "fsdp", + "sequence", + "context", + "expert" ], - "stage" + "stage", + null ], "shape": [ 2880, - 12 + 12, + 32 ] }, - ".params/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1027,18 +2248,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "tensor", - "tensor_transpose" + 
"fsdp", + "sequence", + "context" ], - "stage" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ + 32, + 12, 2880, - 12 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1072,22 +2303,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_sequence" + ] ], "shape": [ - 8, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1121,30 +2351,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ], - "stage", [ + "fsdp_transpose", "tensor", - "tensor_transpose", "tensor_sequence", "autoregressive" - ], - null + ] ], "shape": [ - 2880, + 32, 12, - 8, - 64 + 2880, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1178,22 +2406,21 @@ } }, "partition_spec": [ + "expert", + "stage", [ - "fsdp", - "fsdp_transpose", - "sequence", + "tensor", "tensor_transpose", - "context", - "expert" - ], - "stage" + "tensor_sequence" + ] ], "shape": [ - 2880, - 12 + 32, + 12, + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1227,30 +2454,28 @@ } }, "partition_spec": [ + "expert", + "stage", [ + "fsdp_transpose", "tensor", - 
"tensor_transpose", "tensor_sequence", "autoregressive" ], - "stage", - null, [ "fsdp", - "fsdp_transpose", "sequence", - "context", - "expert" + "context" ] ], "shape": [ - 64, + 32, 12, - 64, + 2880, 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1284,22 +2509,20 @@ } }, "partition_spec": [ + "expert", + "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null + "tensor_transpose" + ] ], "shape": [ - 64, + 32, 12, - 64 + 2880 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1332,31 +2555,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 64, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1389,16 +2593,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 64, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1431,23 +2631,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], 
- "stage", - null - ], + "partition_spec": [], "shape": [ - 8, - 12, - 64 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1480,31 +2669,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 2880, - 12, - 8, - 64 + "partition_spec": [], + "shape": [ + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1537,16 +2707,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - null, - "stage" - ], + "partition_spec": [], "shape": [ - 32, 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['model']/['decoder']/['layers']/['layers_1']/['dropout']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1579,25 +2745,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - null - ], + "partition_spec": [], "shape": [ - 2880, - 12, - 32 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['model']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1631,29 +2784,18 @@ } }, "partition_spec": [ - "expert", - "stage", - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ], [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + 
"stage" ], "shape": [ - 32, - 12, 2880, - 2880 + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['model']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1687,21 +2829,18 @@ } }, "partition_spec": [ - "expert", - "stage", [ "tensor", - "tensor_transpose", - "tensor_sequence" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 32, - 12, - 2880 + 2880, + 12 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['model']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1735,29 +2874,25 @@ } }, "partition_spec": [ - "expert", - "stage", [ "fsdp", "sequence", - "tensor_transpose", - "context" + "context", + "expert" ], [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ] ], "shape": [ - 32, - 12, 2880, - 2880 + 201088 ] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1790,22 +2925,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1838,30 +2961,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "fsdp_transpose", - "tensor", - "tensor_sequence", - "autoregressive" - ], - [ - "fsdp", - "sequence", - "tensor_transpose", - "context" - ] - ], - "shape": [ - 32, - 12, - 2880, - 2880 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1894,21 +2997,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - "expert", - "stage", - [ - "tensor", - "tensor_transpose" - ] - ], - "shape": [ - 32, - 12, - 2880 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1941,19 +3033,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1986,19 +3069,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 2880, - 12 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['logits_dense']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2031,27 +3105,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ] - ], - "shape": [ - 2880, - 201088 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2093,7 +3150,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2104,7 +3160,7 @@ 2880 ] }, - ".opt_state/[0]/.count": { + 
"['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2140,7 +3196,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2183,7 +3239,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2232,7 +3288,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2268,7 +3324,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2289,7 +3344,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2325,9 +3380,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2338,7 +3391,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2382,7 +3435,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2395,7 +3447,7 @@ 2880 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2444,7 +3496,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2480,7 +3532,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2501,7 +3552,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2543,7 +3594,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2592,7 +3643,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2628,7 +3679,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2649,7 +3699,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2691,7 
+3741,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2727,9 +3777,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -2742,7 +3790,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2781,7 +3829,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2798,7 +3845,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2846,7 +3893,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2885,7 +3932,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -2902,7 +3948,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2950,7 +3996,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2995,7 +4041,6 @@ [ 
"fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3006,7 +4051,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3053,7 +4098,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3098,7 +4143,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3143,7 +4188,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3192,7 +4237,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3228,7 +4273,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3249,7 +4293,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3285,9 
+4329,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3298,7 +4340,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3342,7 +4384,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3355,7 +4396,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3404,7 +4445,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3440,7 +4481,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3461,7 +4501,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3503,7 +4543,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3552,7 +4592,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3588,7 +4628,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -3609,7 +4648,7 @@ 64 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3651,7 +4690,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3687,9 +4726,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -3702,7 +4739,7 @@ 32 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3741,7 +4778,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3758,7 +4794,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3806,7 +4842,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3845,7 +4881,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -3862,7 
+4897,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3910,7 +4945,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -3955,7 +4990,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -3966,7 +5000,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4013,7 +5047,7 @@ 2880 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4058,7 +5092,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4103,7 +5137,7 @@ 12 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4139,7 +5173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4156,7 +5189,7 @@ 201088 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4198,7 +5231,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4209,7 +5241,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4252,7 +5284,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4301,7 +5333,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4337,7 +5369,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4358,7 +5389,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4394,9 +5425,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4407,7 +5436,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4451,7 +5480,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4464,7 
+5492,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4513,7 +5541,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4549,7 +5577,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4570,7 +5597,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4612,7 +5639,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4661,7 +5688,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4697,7 +5724,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -4718,7 +5744,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": 
[ "diloco", @@ -4760,7 +5786,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4796,9 +5822,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -4811,7 +5835,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4850,7 +5874,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4867,7 +5890,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4915,7 +5938,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -4954,7 +5977,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -4971,7 +5993,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5019,7 +6041,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ 
-5064,7 +6086,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -5075,7 +6096,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5122,7 +6143,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5167,7 +6188,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_0']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5212,7 +6233,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5261,7 +6282,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5297,7 +6318,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5318,7 +6338,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['bias']/.value": { "mesh": { "axis_names": [ 
"diloco", @@ -5354,9 +6374,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5367,7 +6385,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5411,7 +6429,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5424,7 +6441,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5473,7 +6490,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5509,7 +6526,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5530,7 +6546,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['sinks']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5572,7 +6588,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5621,7 +6637,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssAttention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5657,7 +6673,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -5678,7 +6693,7 @@ 64 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5720,7 +6735,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['gate']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5756,9 +6771,7 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", - "tensor_transpose", "context", "expert" ], @@ -5771,7 +6784,7 @@ 32 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5810,7 +6823,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5827,7 +6839,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_0_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5875,7 +6887,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5914,7 +6926,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ], [ @@ -5931,7 
+6942,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wi_1_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -5979,7 +6990,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6024,7 +7035,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context" ] ], @@ -6035,7 +7045,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['GptOssMlp']/['wo_bias']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6082,7 +7092,7 @@ 2880 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6127,7 +7137,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['layers_1']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6172,7 +7182,7 @@ 12 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['logits_dense']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['logits_dense']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6208,7 +7218,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6225,7 +7234,7 @@ 201088 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -6267,7 +7276,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -6278,7 +7286,43 @@ 2880 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json index 0d5b2d8c24..48a289fcb6 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[192,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attentions/out: bfloat16[192,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json index 487e9bb959..d23242f925 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json +++ 
b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/logical_shardings.json @@ -1,21 +1,90 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -23,11 +92,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -35,11 +114,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -47,32 +136,84 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -81,22 +222,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -105,12 +259,22 @@ 1024 ] }, - 
".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -119,22 +283,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -143,33 +320,80 @@ 128 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", 
- "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -177,11 +401,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -189,11 +423,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -201,32 +445,48 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -235,22 +495,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -259,12 +532,22 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { 
+ "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -273,22 +556,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -297,29 +593,52 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -327,11 +646,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -339,11 +668,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -351,32 +690,48 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -385,22 +740,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -409,12 +777,22 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -423,22 +801,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -447,17 +838,31 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/named_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/named_shardings.json index 6208b4ba80..84f70e840d 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_1/named_shardings.json @@ -1,5 +1,896 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + 
], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 1024 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + 
"fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, 
+ "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + 
"stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 3072, + 28, + 1024 + ] + }, + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": 
[ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + 
"axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -33,9 +924,11 @@ } }, "partition_spec": [], - "shape": [] + "shape": [ + 28 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +961,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 1024 + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -115,25 +1003,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 1024, 28, - 3072 + 8, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -168,27 +1057,17 @@ }, "partition_spec": [ [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - [ - "fsdp_transpose", "tensor", - "tensor_sequence", - 
"autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 1024, - 28, - 3072 + 128, + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -223,27 +1102,28 @@ }, "partition_spec": [ [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], "stage", + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] ], "shape": [ - 3072, + 16, 28, + 128, 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -277,18 +1157,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], - "stage" + null ], "shape": [ 1024, - 28 + 28, + 16, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -329,11 +1220,11 @@ "stage" ], "shape": [ - 1024, + 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -369,7 +1260,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -390,7 +1280,7 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -423,19 +1313,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" 
- ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +1349,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 28, - 128, - 1024 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,31 +1385,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 1024, - 28, - 16, - 128 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -582,19 +1421,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -627,31 +1457,46 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", 
"sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null - ], - "shape": [ - 1024, - 28, - 8, - 128 - ] + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -693,7 +1538,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -704,7 +1548,7 @@ 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -740,7 +1584,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -783,7 +1627,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -820,7 +1664,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -838,7 +1681,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -875,7 +1718,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -893,7 +1735,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -937,7 +1779,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -948,7 +1789,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -993,7 +1834,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1038,7 +1879,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1074,7 +1915,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1095,7 +1935,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1140,7 +1980,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1184,7 +2024,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1197,7 +2036,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1233,7 +2072,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1254,7 +2092,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1299,7 +2137,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1335,7 +2173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1356,7 +2193,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1398,7 +2235,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1409,7 +2245,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1452,7 +2288,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1489,7 +2325,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1507,7 +2342,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1544,7 +2379,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1562,7 +2396,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1606,7 +2440,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1617,7 +2450,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1662,7 +2495,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1707,7 +2540,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1743,7 +2576,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1764,7 +2596,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1809,7 +2641,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1853,7 +2685,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1866,7 +2697,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1902,7 +2733,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1923,7 +2753,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1968,7 +2798,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2004,7 +2834,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2025,7 +2854,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2067,7 +2896,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2078,7 +2906,43 @@ 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + 
"data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json index 2146f74797..2d9f407653 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[768,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attentions/out: bfloat16[768,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json index 487e9bb959..d23242f925 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/logical_shardings.json @@ -1,21 +1,90 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + 
"['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -23,11 +92,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - 
"mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -35,11 +114,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -47,32 +136,84 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + 
"['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -81,22 +222,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -105,12 +259,22 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -119,22 +283,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + 
"['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -143,33 +320,80 @@ 128 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] 
], "shape": [ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -177,11 +401,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -189,11 +423,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -201,32 +445,48 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + 
], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -235,22 +495,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -259,12 +532,22 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -273,22 +556,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -297,29 +593,52 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -327,11 +646,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": 
{ "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -339,11 +668,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -351,32 +690,48 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -385,22 +740,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -409,12 +777,22 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -423,22 +801,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], 
"shape": [ 1024, @@ -447,17 +838,31 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/named_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/named_shardings.json index 31499e643e..661e39e64f 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/tpu7x-16/slice_4/named_shardings.json @@ -1,5 +1,896 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 1024 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + 
"shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, 
+ "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + 
"expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 3072, + 28, + 1024 + ] + }, + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", 
+ "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + 
"stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -33,9 
+924,11 @@ } }, "partition_spec": [], - "shape": [] + "shape": [ + 28 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +961,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 1024 + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -115,25 +1003,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 1024, 28, - 3072 + 8, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -168,27 +1057,17 @@ }, "partition_spec": [ [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 1024, - 28, - 3072 + 128, + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -223,27 +1102,28 @@ }, "partition_spec": [ [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], "stage", + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] ], "shape": [ - 3072, + 16, 28, + 128, 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + 
"['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -277,18 +1157,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], - "stage" + null ], "shape": [ 1024, - 28 + 28, + 16, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -329,11 +1220,11 @@ "stage" ], "shape": [ - 1024, + 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -369,7 +1260,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -390,7 +1280,7 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -423,19 +1313,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +1349,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 28, - 128, - 1024 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,31 +1385,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 1024, - 28, - 16, - 128 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -582,19 +1421,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -627,31 +1457,46 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null - ], - "shape": [ - 1024, - 28, - 8, - 128 - ] + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + 
"['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -693,7 +1538,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -704,7 +1548,7 @@ 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -740,7 +1584,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -783,7 +1627,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -820,7 +1664,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -838,7 +1681,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -875,7 +1718,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -893,7 +1735,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -937,7 +1779,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -948,7 +1789,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -993,7 +1834,7 @@ 28 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1038,7 +1879,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1074,7 +1915,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1095,7 +1935,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1140,7 +1980,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1184,7 +2024,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1197,7 +2036,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1233,7 +2072,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1254,7 +2092,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1299,7 +2137,7 @@ 28 
] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1335,7 +2173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1356,7 +2193,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1398,7 +2235,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1409,7 +2245,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1452,7 +2288,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1489,7 +2325,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1507,7 +2342,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1544,7 +2379,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1562,7 +2396,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1606,7 +2440,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1617,7 +2450,7 @@ 1024 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1662,7 +2495,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1707,7 +2540,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1743,7 +2576,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1764,7 +2596,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1809,7 +2641,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1853,7 +2685,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1866,7 +2697,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1902,7 +2733,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1923,7 +2753,7 @@ 
128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1968,7 +2798,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2004,7 +2834,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2025,7 +2854,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2067,7 +2896,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2078,7 +2906,43 @@ 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json index 4a5224cd6d..5b56cb2b17 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json +++ 
b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[96,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attentions/out: bfloat16[96,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json index 487e9bb959..d23242f925 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/logical_shardings.json @@ -1,21 +1,90 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -23,11 +92,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -35,11 +114,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -47,32 +136,84 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { 
+ "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -81,22 +222,35 @@ 128 ] }, - 
".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -105,12 +259,22 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -119,22 +283,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -143,33 +320,80 @@ 128 ] }, - 
".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -177,11 +401,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { 
"partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -189,11 +423,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -201,32 +445,48 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -235,22 +495,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -259,12 +532,22 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -273,22 +556,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], 
"shape": [ 1024, @@ -297,29 +593,52 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -327,11 +646,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -339,11 +668,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -351,32 +690,48 @@ 1024 ] 
}, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -385,22 +740,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -409,12 +777,22 @@ 
1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -423,22 +801,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -447,17 +838,31 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/named_shardings.json 
b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/named_shardings.json index 2cce1577f2..beee4d6788 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_1/named_shardings.json @@ -1,5 +1,896 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 1024 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + 
"sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 
1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 
28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + 
"['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + 
"autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 3072, + 28, + 1024 + ] + }, + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + 
"tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + 
"fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -33,9 +924,11 @@ } }, "partition_spec": [], - "shape": [] + "shape": [ + 28 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +961,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 1024 + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -115,25 +1003,26 @@ [ "fsdp", 
"sequence", - "tensor_transpose", "context", "expert" ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 1024, 28, - 3072 + 8, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -168,27 +1057,17 @@ }, "partition_spec": [ [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 1024, - 28, - 3072 + 128, + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -223,27 +1102,28 @@ }, "partition_spec": [ [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], "stage", + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] ], "shape": [ - 3072, + 16, 28, + 128, 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -277,18 +1157,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], - "stage" + null ], "shape": [ 1024, - 28 + 28, + 16, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -329,11 +1220,11 @@ "stage" ], "shape": [ - 1024, + 128, 28 ] }, - 
".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -369,7 +1260,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -390,7 +1280,7 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -423,19 +1313,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +1349,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 28, - 128, - 1024 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,31 +1385,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 1024, - 28, - 16, - 128 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -582,19 +1421,10 @@ 
"autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -627,31 +1457,46 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null - ], - "shape": [ - 1024, - 28, - 8, - 128 - ] + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -693,7 +1538,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -704,7 +1548,7 @@ 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -740,7 +1584,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -783,7 +1627,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -820,7 +1664,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -838,7 +1681,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -875,7 +1718,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -893,7 +1735,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -937,7 +1779,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -948,7 +1789,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -993,7 +1834,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1038,7 +1879,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1074,7 +1915,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1095,7 +1935,7 @@ 128 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1140,7 +1980,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1184,7 +2024,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1197,7 +2036,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1233,7 +2072,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1254,7 +2092,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1299,7 +2137,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1335,7 +2173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1356,7 +2193,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1398,7 +2235,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ 
-1409,7 +2245,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1452,7 +2288,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1489,7 +2325,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1507,7 +2342,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1544,7 +2379,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1562,7 +2396,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1606,7 +2440,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1617,7 +2450,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1662,7 +2495,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1707,7 +2540,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1743,7 +2576,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1764,7 +2596,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1809,7 +2641,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1853,7 +2685,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1866,7 +2697,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1902,7 +2733,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1923,7 +2753,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1968,7 +2798,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2004,7 +2834,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2025,7 +2854,7 @@ 128 ] 
}, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2067,7 +2896,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2078,7 +2906,43 @@ 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json index 6bb047297d..2398f4f8ca 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[384,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attentions/out: bfloat16[384,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json index 487e9bb959..d23242f925 100644 --- 
a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/logical_shardings.json @@ -1,21 +1,90 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + 
"['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -23,11 +92,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -35,11 +114,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -47,32 +136,84 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + 
"['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -81,22 +222,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + 
"sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -105,12 +259,22 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -119,22 +283,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -143,33 +320,80 @@ 128 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + 
}, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -177,11 +401,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -189,11 +423,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -201,32 +445,48 @@ 1024 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -235,22 +495,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -259,12 +532,22 @@ 1024 
] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -273,22 +556,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -297,29 +593,52 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -327,11 +646,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -339,11 +668,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -351,32 +690,48 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": 
[ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -385,22 +740,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -409,12 +777,22 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -423,22 +801,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -447,17 +838,31 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/named_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/named_shardings.json index b9512d15f0..8b15b25fac 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v5p-16/slice_4/named_shardings.json @@ -1,5 +1,896 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + 
"shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 1024 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + 
"fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 
+ } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ 
+ 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + 
"fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 3072, + 28, + 1024 + ] + }, + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + 
"stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + 
"stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -33,9 +924,11 @@ } }, "partition_spec": [], - "shape": [] + "shape": [ + 28 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +961,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 1024 + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -115,25 +1003,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 1024, 28, - 3072 + 8, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -168,27 +1057,17 @@ }, "partition_spec": [ [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" 
+ ], + "stage" ], "shape": [ - 1024, - 28, - 3072 + 128, + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -223,27 +1102,28 @@ }, "partition_spec": [ [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], "stage", + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] ], "shape": [ - 3072, + 16, 28, + 128, 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -277,18 +1157,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], - "stage" + null ], "shape": [ 1024, - 28 + 28, + 16, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -329,11 +1220,11 @@ "stage" ], "shape": [ - 1024, + 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -369,7 +1260,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -390,7 +1280,7 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -423,19 +1313,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 
28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +1349,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 28, - 128, - 1024 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,31 +1385,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 1024, - 28, - 16, - 128 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -582,19 +1421,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -627,31 +1457,46 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", 
- [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null - ], - "shape": [ - 1024, - 28, - 8, - 128 - ] + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -693,7 +1538,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -704,7 +1548,7 @@ 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -740,7 +1584,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -783,7 +1627,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -820,7 +1664,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -838,7 +1681,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -875,7 +1718,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -893,7 +1735,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -937,7 +1779,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -948,7 +1789,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -993,7 +1834,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1038,7 +1879,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1074,7 +1915,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1095,7 +1935,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1140,7 +1980,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1184,7 +2024,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1197,7 +2036,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1233,7 +2072,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1254,7 +2092,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1299,7 +2137,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1335,7 +2173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1356,7 +2193,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1398,7 +2235,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1409,7 +2245,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1452,7 +2288,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1489,7 +2325,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1507,7 +2342,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1544,7 +2379,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1562,7 +2396,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1606,7 +2440,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1617,7 +2450,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1662,7 +2495,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1707,7 +2540,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1743,7 +2576,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1764,7 +2596,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1809,7 +2641,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1853,7 +2685,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1866,7 +2697,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1902,7 +2733,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1923,7 +2753,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1968,7 +2798,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2004,7 +2834,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2025,7 +2854,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2067,7 +2896,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2078,7 +2906,43 @@ 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + 
"data": 4, + "stage": 1, + "fsdp": 8, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json index 0d5b2d8c24..48a289fcb6 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P('fsdp', None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[192,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P('fsdp', None)" + } + }, { "attentions/out: bfloat16[192,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json index 487e9bb959..d23242f925 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/logical_shardings.json @@ -1,21 +1,90 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + 
"partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -23,11 +92,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ 
+ "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -35,11 +114,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -47,32 +136,84 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -81,22 +222,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -105,12 +259,22 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -119,22 +283,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + 
"tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -143,33 +320,80 @@ 128 ] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -177,11 +401,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -189,11 +423,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -201,32 +445,48 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -235,22 +495,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -259,12 +532,22 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -273,22 +556,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -297,29 +593,52 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -327,11 +646,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": 
{ "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -339,11 +668,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -351,32 +690,48 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -385,22 +740,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -409,12 +777,22 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -423,22 +801,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], 
"shape": [ 1024, @@ -447,17 +838,31 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/named_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/named_shardings.json index 6208b4ba80..84f70e840d 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/named_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_1/named_shardings.json @@ -1,5 +1,896 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 1024 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + 
"diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + 
"context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, 
+ "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + 
"expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 3072, + 28, + 1024 + ] + }, + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", 
+ "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + 
"stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -33,9 
+924,11 @@ } }, "partition_spec": [], - "shape": [] + "shape": [ + 28 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +961,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 1024 + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -115,25 +1003,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 1024, 28, - 3072 + 8, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -168,27 +1057,17 @@ }, "partition_spec": [ [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 1024, - 28, - 3072 + 128, + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -223,27 +1102,28 @@ }, "partition_spec": [ [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], "stage", + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] ], "shape": [ - 3072, + 16, 28, + 128, 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + 
"['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -277,18 +1157,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], - "stage" + null ], "shape": [ 1024, - 28 + 28, + 16, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -329,11 +1220,11 @@ "stage" ], "shape": [ - 1024, + 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -369,7 +1260,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -390,7 +1280,7 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -423,19 +1313,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +1349,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 28, - 128, - 1024 - ] + "partition_spec": [], + "shape": [] }, - 
".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,31 +1385,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 1024, - 28, - 16, - 128 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -582,19 +1421,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -627,31 +1457,46 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null - ], - "shape": [ - 1024, - 28, - 8, - 128 - ] + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + 
"['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -693,7 +1538,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -704,7 +1548,7 @@ 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -740,7 +1584,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -783,7 +1627,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -820,7 +1664,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -838,7 +1681,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -875,7 +1718,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -893,7 +1735,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -937,7 +1779,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -948,7 +1789,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -993,7 +1834,7 @@ 28 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1038,7 +1879,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1074,7 +1915,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1095,7 +1935,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1140,7 +1980,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1184,7 +2024,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1197,7 +2036,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1233,7 +2072,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1254,7 +2092,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1299,7 +2137,7 @@ 28 
] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1335,7 +2173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1356,7 +2193,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1398,7 +2235,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1409,7 +2245,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1452,7 +2288,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1489,7 +2325,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1507,7 +2342,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1544,7 +2379,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1562,7 +2396,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1606,7 +2440,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1617,7 +2450,7 @@ 1024 ] }, - 
".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1662,7 +2495,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1707,7 +2540,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1743,7 +2576,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1764,7 +2596,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1809,7 +2641,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1853,7 +2685,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1866,7 +2697,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1902,7 +2733,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1923,7 +2753,7 @@ 
128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1968,7 +2798,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2004,7 +2834,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2025,7 +2854,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2067,7 +2896,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2078,7 +2906,43 @@ 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 1, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json index 2146f74797..2d9f407653 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json +++ 
b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/input_shardings.json @@ -48,6 +48,12 @@ "PartitionSpec": "P(('data', 'fsdp'), None, None, None)" } }, + { + "attention_op/decoder_segment_ids: int32[768,2048]": { + "logic_axes": "Unknown", + "PartitionSpec": "P(('data', 'fsdp'), None)" + } + }, { "attentions/out: bfloat16[768,2048,16,128]": { "logic_axes": "('activation_batch', 'activation_attn_length_no_exp', 'activation_heads', 'activation_kv')", diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json index 487e9bb959..d23242f925 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/logical_shardings.json @@ -1,21 +1,90 @@ { - ".step": { - "partition_spec": [], - "shape": [] - }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + 
"['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -23,11 +92,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -35,11 +114,21 @@ 3072 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -47,32 +136,84 @@ 1024 ] }, 
- ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], 
"shape": [ 1024, @@ -81,22 +222,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -105,12 +259,22 @@ 1024 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -119,22 +283,35 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -143,33 +320,80 @@ 128 ] }, - 
".params/['params']/['token_embedder']/['embedding']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "partition_spec": [], + "shape": [] + }, + "['model']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -177,11 +401,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { 
"partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -189,11 +423,21 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -201,32 +445,48 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -235,22 +495,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -259,12 +532,22 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -273,22 +556,35 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], 
"shape": [ 1024, @@ -297,29 +593,52 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "partition_spec": [ - "norm" + [ + "tensor", + "tensor_transpose" + ] ], "shape": [ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -327,11 +646,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "mlp" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] ], "shape": [ 1024, @@ -339,11 +668,21 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "partition_spec": [ - "mlp", - "layers", - "embed" + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 3072, @@ -351,32 +690,48 @@ 1024 ] 
}, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 1024, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -385,22 +740,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "partition_spec": [ - "heads", - "layers", - "kv", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + "stage", + null, + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 16, @@ -409,12 +777,22 @@ 
1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "q_heads", - "kv" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -423,22 +801,35 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "partition_spec": [ - "norm", - "layers" + [ + "tensor", + "tensor_transpose" + ], + "stage" ], "shape": [ 128, 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "partition_spec": [ - "embed", - "layers", - "kv_heads", - "kv_head_dim" + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + null ], "shape": [ 1024, @@ -447,17 +838,31 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "partition_spec": [ - "vocab", - "embed" + [ + "tensor", + "tensor_transpose", + "tensor_sequence", + "autoregressive" + ], + [ + "fsdp", + "sequence", + "context", + "expert" + ] ], "shape": [ 151936, 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "partition_spec": [], "shape": [] } diff --git a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/named_shardings.json 
b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/named_shardings.json index 31499e643e..661e39e64f 100644 --- a/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/named_shardings.json +++ b/tests/utils/sharding_info/qwen3-0.6b/v6e-16/slice_4/named_shardings.json @@ -1,5 +1,896 @@ { - ".step": { + "['model']/['decoder']/['decoder_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ] + ], + "shape": [ + 1024 + ] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + 
"sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + 
"tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + 
"partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 
+ ] + }, + "['model']/['decoder']/['layers']/['mlp']/['dropout']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + 
"expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ] + ], + "shape": [ + 1024, + 28, + 3072 + ] + }, + "['model']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "fsdp_transpose", + "tensor", + "tensor_sequence", + "autoregressive" + ], + "stage", + [ + "fsdp", + "sequence", + "context", + "expert" + ] + ], + "shape": [ + 3072, + 28, + 1024 + ] + }, + "['model']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + 
"context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [ + [ + "tensor", + "tensor_transpose" + ], + "stage" + ], + "shape": [ + 1024, + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['aqt']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + 
"stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['dropout']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [ + 28 + ] + }, + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -33,9 +924,11 @@ } }, "partition_spec": [], - "shape": [] + "shape": [ + 28 + ] }, - ".params/['params']/['decoder']/['decoder_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['attention_op']/['rngs']/['params']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -68,17 +961,12 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ] - ], + "partition_spec": [], "shape": [ - 1024 + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -115,25 
+1003,26 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], "stage", [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" - ] + ], + null ], "shape": [ 1024, 28, - 3072 + 8, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -168,27 +1057,17 @@ }, "partition_spec": [ [ - "fsdp", - "sequence", - "tensor_transpose", - "context", - "expert" - ], - "stage", - [ - "fsdp_transpose", "tensor", - "tensor_sequence", - "autoregressive" - ] + "tensor_transpose" + ], + "stage" ], "shape": [ - 1024, - 28, - 3072 + 128, + 28 ] }, - ".params/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -223,27 +1102,28 @@ }, "partition_spec": [ [ - "fsdp_transpose", "tensor", + "tensor_transpose", "tensor_sequence", "autoregressive" ], "stage", + null, [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] ], "shape": [ - 3072, + 16, 28, + 128, 1024 ] }, - ".params/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -277,18 +1157,29 @@ } }, "partition_spec": [ + [ + "fsdp", + "sequence", + "context", + "expert" + ], + "stage", [ "tensor", - "tensor_transpose" + "tensor_transpose", + "tensor_sequence", + "autoregressive" ], - "stage" + null ], "shape": [ 1024, - 28 + 28, + 16, + 128 ] }, - ".params/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['model']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -329,11 +1220,11 @@ "stage" ], "shape": [ - 1024, + 128, 28 ] }, - 
".params/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['model']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -369,7 +1260,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -390,7 +1280,7 @@ 128 ] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['aqt']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -423,19 +1313,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['model']/['decoder']/['rngs']/['aqt']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -468,31 +1349,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - "stage", - null, - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ] - ], - "shape": [ - 16, - 28, - 128, - 1024 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['model']/['decoder']/['rngs']/['dropout']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -525,31 +1385,10 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ - "fsdp", - "fsdp_transpose", - "sequence", - "context", - "expert" - ], - "stage", - [ - "tensor", - "tensor_transpose", - "tensor_sequence", - "autoregressive" - ], - null - ], - "shape": [ - 1024, - 28, - 16, - 128 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['model']/['decoder']/['rngs']/['dropout']/['key']/.value": { "mesh": { "axis_names": [ "diloco", @@ -582,19 +1421,10 @@ 
"autoregressive": 1 } }, - "partition_spec": [ - [ - "tensor", - "tensor_transpose" - ], - "stage" - ], - "shape": [ - 128, - 28 - ] + "partition_spec": [], + "shape": [] }, - ".params/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['model']/['decoder']/['rngs']/['params']/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -627,31 +1457,46 @@ "autoregressive": 1 } }, - "partition_spec": [ - [ + "partition_spec": [], + "shape": [] + }, + "['model']/['decoder']/['rngs']/['params']/['key']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", "fsdp", "fsdp_transpose", "sequence", "context", - "expert" - ], - "stage", - [ + "context_autoregressive", "tensor", "tensor_transpose", "tensor_sequence", + "expert", "autoregressive" ], - null - ], - "shape": [ - 1024, - 28, - 8, - 128 - ] + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] }, - ".params/['params']/['token_embedder']/['embedding']": { + "['model']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -693,7 +1538,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -704,7 +1548,7 @@ 1024 ] }, - ".opt_state/[0]/.count": { + "['optimizer']/['opt_state']/[0]/['count']/.value": { "mesh": { "axis_names": [ "diloco", @@ -740,7 +1584,7 @@ "partition_spec": [], "shape": [] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -783,7 +1627,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -820,7 +1664,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -838,7 +1681,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -875,7 +1718,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -893,7 +1735,7 @@ 3072 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -937,7 +1779,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -948,7 +1789,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -993,7 +1834,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1038,7 +1879,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1074,7 +1915,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1095,7 +1935,7 @@ 128 ] }, - 
".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1140,7 +1980,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1184,7 +2024,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1197,7 +2036,7 @@ 1024 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1233,7 +2072,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1254,7 +2092,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1299,7 +2137,7 @@ 28 ] }, - ".opt_state/[0]/.mu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['mu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1335,7 +2173,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1356,7 +2193,7 @@ 128 ] }, - ".opt_state/[0]/.mu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['mu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1398,7 +2235,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ 
-1409,7 +2245,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['decoder_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['decoder_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1452,7 +2288,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_0']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1489,7 +2325,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1507,7 +2342,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wi_1']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1544,7 +2379,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ], @@ -1562,7 +2396,7 @@ 3072 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['mlp']/['wo']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['mlp']/['wo']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1606,7 +2440,6 @@ [ "fsdp", "sequence", - "tensor_transpose", "context", "expert" ] @@ -1617,7 +2450,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['post_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1662,7 +2495,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['pre_self_attention_layer_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1707,7 +2540,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key']/['kernel']": { + 
"['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1743,7 +2576,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1764,7 +2596,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['key_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1809,7 +2641,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['out']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['out']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1853,7 +2685,6 @@ null, [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1866,7 +2697,7 @@ 1024 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1902,7 +2733,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -1923,7 +2753,7 @@ 128 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['query_norm']/['scale']/.value": { "mesh": { "axis_names": [ "diloco", @@ -1968,7 +2798,7 @@ 28 ] }, - ".opt_state/[0]/.nu/['params']/['decoder']/['layers']/['self_attention']/['value']/['kernel']": { + "['optimizer']/['opt_state']/[0]/['nu']/['decoder']/['layers']/['self_attention']/['value']/['kernel']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2004,7 +2834,6 @@ "partition_spec": [ [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2025,7 +2854,7 @@ 128 ] 
}, - ".opt_state/[0]/.nu/['params']/['token_embedder']/['embedding']": { + "['optimizer']/['opt_state']/[0]/['nu']/['token_embedder']/['embedding']/.value": { "mesh": { "axis_names": [ "diloco", @@ -2067,7 +2896,6 @@ ], [ "fsdp", - "fsdp_transpose", "sequence", "context", "expert" @@ -2078,7 +2906,43 @@ 1024 ] }, - ".opt_state/[2]/.count": { + "['optimizer']/['opt_state']/[2]/['count']/.value": { + "mesh": { + "axis_names": [ + "diloco", + "data", + "stage", + "fsdp", + "fsdp_transpose", + "sequence", + "context", + "context_autoregressive", + "tensor", + "tensor_transpose", + "tensor_sequence", + "expert", + "autoregressive" + ], + "shape": { + "diloco": 1, + "data": 4, + "stage": 1, + "fsdp": 16, + "fsdp_transpose": 1, + "sequence": 1, + "context": 1, + "context_autoregressive": 1, + "tensor": 1, + "tensor_transpose": 1, + "tensor_sequence": 1, + "expert": 1, + "autoregressive": 1 + } + }, + "partition_spec": [], + "shape": [] + }, + "['optimizer']/['step']/.value": { "mesh": { "axis_names": [ "diloco",