diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 35a08c105..1ea3f813c 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Enforce the array shape and type check during Array restoration when `ArrayRestoreArgs.strict` is set but shape/dtype is not provided. - On platforms where `uvloop` is not supported, fallback to `nest_asyncio`. +- #v1 Centralize `StorageOptions` into `ArrayOptions` and implement field-level +merging. ## [0.11.33] - 2026-02-17 diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py index 5f75fb1ff..6d6372aa2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py @@ -184,27 +184,6 @@ class PyTreeOptions: # TODO: Include an example of registering a custom LeafHandler. - Example: - To save certain leaves in float16, while others in float32, we can use - `create_array_storage_options_fn` like so:: - - import jax - import jax.numpy as jnp - from orbax.checkpoint.v1 import options as ocp_options - - def create_opts_fn(keypath, value): - if 'small' in jax.tree_util.keystr(keypath): - return ocp_options.ArrayOptions.Saving.StorageOptions( - dtype=jnp.float16 - ) - return ocp_options.ArrayOptions.Saving.StorageOptions(dtype=jnp.float32) - - pytree_options = ocp_options.PyTreeOptions( - saving=ocp_options.PyTreeOptions.Saving( - create_array_storage_options_fn=create_opts_fn - ) - ) - Attributes: saving: Options for saving PyTrees. loading: Options for loading PyTrees. @@ -216,25 +195,9 @@ def create_opts_fn(keypath, value): class Saving: """Options for saving PyTrees. 
- create_array_storage_options_fn: - A function that is called in order to create - :py:class:`.ArrayOptions.Saving.StorageOptions` for each leaf in a PyTree, - when it is - being saved. It is called similar to: - `jax.tree.map_with_path(create_array_storage_options_fn, pytree_to_save)`. - If provided, it overrides any default settings in - :py:class:`.ArrayOptions.Saving.StorageOptions`. pytree_metadata_options: Options for managing PyTree metadata. """ - class CreateArrayStorageOptionsFn(Protocol): - - def __call__( - self, key: tree_types.PyTreeKeyPath, value: Any - ) -> ArrayOptions.Saving.StorageOptions: - ... - - create_array_storage_options_fn: CreateArrayStorageOptionsFn | None = None pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = ( dataclasses.field(default_factory=tree_metadata.PyTreeMetadataOptions) ) @@ -265,7 +228,8 @@ class ArrayOptions: names during initialization. Example: - Configure array options with specific saving formats and loading behaviors:: + To configure array options with specific saving formats and loading + behaviors we can do so like this:: from orbax.checkpoint.v1.options import ArrayOptions @@ -280,6 +244,30 @@ class ArrayOptions: ) ) + To save certain leaves in float16, while others in float32, we can use + `scoped_storage_options_creator` like so:: + + import jax + import jax.numpy as jnp + from orbax.checkpoint.v1 import options as ocp_options + + def create_opts_fn(keypath, value): + if 'small' in jax.tree_util.keystr(keypath): + return ocp_options.ArrayOptions.Saving.StorageOptions( + dtype=jnp.float16 + ) + return None # Fall back to global `storage_options` + + array_options = ocp_options.ArrayOptions( + saving=ocp_options.ArrayOptions.Saving( + storage_options=ocp_options.ArrayOptions.Saving.StorageOptions( + dtype=jnp.float32 + ), + scoped_storage_options_creator=create_opts_fn + ) + + ) + Attributes: saving: Options for saving arrays. loading: Options for loading arrays. 
@@ -322,8 +310,24 @@ class Saving:
         True.
       array_metadata_store: Store to manage per host ArrayMetadata. To disable
         ArrayMetadata persistence, set it to None.
+      storage_options: Global default for array storage options.
+      scoped_storage_options_creator: A function that, when dealing with
+        PyTrees, is called for every leaf, similar to:
+        `jax.tree.map_with_path(scoped_storage_options_creator,
+        pytree_to_save)`. If it returns an
+        :py:class:`ArrayOptions.Saving.StorageOptions`, each field it sets
+        overrides `storage_options`; fields left as `None` (or an empty
+        `shard_axes`) inherit the `storage_options` value. If it returns
+        `None`, `storage_options` is used for all fields.
     """
 
+    class ScopedStorageOptionsCreator(Protocol):
+
+      def __call__(
+          self, key: tree_types.PyTreeKeyPath, value: Any
+      ) -> ArrayOptions.Saving.StorageOptions | None:
+        """Returns per-leaf overrides, or `None` to use `storage_options`."""
+
     @dataclasses.dataclass(frozen=True, kw_only=True)
     class StorageOptions:
       """Options used to customize array storage behavior for individual leaves.
@@ -367,6 +371,7 @@ class StorageOptions: array_metadata_store: array_metadata_store_lib.Store | None = ( array_metadata_store_lib.Store() ) + scoped_storage_options_creator: ScopedStorageOptionsCreator | None = None @dataclasses.dataclass(frozen=True, kw_only=True) class Loading: diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index 9cb8fd15f..08f8dafcc 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -29,6 +29,7 @@ from orbax.checkpoint._src.futures import synchronization from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib +from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0 from orbax.checkpoint._src.serialization import types as v0_serialization_types from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.context import options as options_lib @@ -36,6 +37,7 @@ from orbax.checkpoint.experimental.v1._src.metadata import types as metadata_types from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.serialization import compatibility +from orbax.checkpoint.experimental.v1._src.serialization import options_resolution from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils from orbax.checkpoint.experimental.v1._src.serialization import registry from orbax.checkpoint.experimental.v1._src.serialization import scalar_leaf_handler @@ -69,32 +71,19 @@ def _get_remaining_timeout( def _get_v0_save_args( checkpointable: PyTree, - array_storage_options: options_lib.ArrayOptions.Saving.StorageOptions, - create_array_storage_options_fn: ( - 
options_lib.PyTreeOptions.Saving.CreateArrayStorageOptionsFn | None - ), + array_saving_options: options_lib.ArrayOptions.Saving, ) -> PyTree: """Returns save args that are compatible with the V0 API.""" - def _leaf_get_v0_save_args(k, v): - if create_array_storage_options_fn: - individual_array_storage_options = create_array_storage_options_fn(k, v) - save_dtype = ( - np.dtype(individual_array_storage_options.dtype) - if individual_array_storage_options.dtype - else None - ) - return v0_serialization_types.SaveArgs( - dtype=save_dtype, - chunk_byte_size=individual_array_storage_options.chunk_byte_size, - shard_axes=individual_array_storage_options.shard_axes, - ) - return v0_serialization_types.SaveArgs( - dtype=np.dtype(array_storage_options.dtype) - if array_storage_options.dtype + resolved_options = options_resolution.resolve_storage_options( + k, v, array_saving_options + ) + return type_handlers_v0.SaveArgs( + dtype=np.dtype(resolved_options.dtype) + if resolved_options.dtype is not None else None, - chunk_byte_size=array_storage_options.chunk_byte_size, - shard_axes=array_storage_options.shard_axes, + chunk_byte_size=resolved_options.chunk_byte_size, + shard_axes=resolved_options.shard_axes, ) return jax.tree.map_with_path(_leaf_get_v0_save_args, checkpointable) @@ -135,8 +124,7 @@ def create_v0_save_args( item=checkpointable, save_args=_get_v0_save_args( checkpointable, - context.array_options.saving.storage_options, - context.pytree_options.saving.create_array_storage_options_fn, + context.array_options.saving, ), ocdbt_target_data_file_size=context.array_options.saving.ocdbt_target_data_file_size, ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py index 3dfcc28a5..44347ca35 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py +++ 
b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py @@ -32,6 +32,8 @@ from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0 from orbax.checkpoint.experimental.v1._src.context import context as context_lib +import orbax.checkpoint.experimental.v1._src.context.options as options_lib +from orbax.checkpoint.experimental.v1._src.serialization import options_resolution from orbax.checkpoint.experimental.v1._src.serialization import protocol_utils from orbax.checkpoint.experimental.v1._src.serialization import registration from orbax.checkpoint.experimental.v1._src.serialization import types @@ -109,18 +111,18 @@ def _create_v0_saving_paraminfo( def _create_v0_savearg( param: ArraySerializationParam, - context: context_lib.Context, + array_saving_options: options_lib.ArrayOptions.Saving, ) -> type_handlers_v0.SaveArgs: - """Creates a V0 `SaveArgs` from V1 params and context for saving.""" - fn = context.pytree_options.saving.create_array_storage_options_fn - if fn: - storage_options = fn(param.keypath, param.value) - else: - storage_options = context.array_options.saving.storage_options + """Creates a V0 `SaveArgs` from V1 params and array options for saving.""" + resolved_options = options_resolution.resolve_storage_options( + param.keypath, param.value, array_saving_options + ) return type_handlers_v0.SaveArgs( - dtype=jnp.dtype(storage_options.dtype) if storage_options.dtype else None, - chunk_byte_size=storage_options.chunk_byte_size, - shard_axes=storage_options.shard_axes, + dtype=jnp.dtype(resolved_options.dtype) + if resolved_options.dtype is not None + else None, + chunk_byte_size=resolved_options.chunk_byte_size, + shard_axes=resolved_options.shard_axes, ) @@ -223,7 +225,10 @@ async def serialize( _create_v0_saving_paraminfo(p, self._context, serialization_context) for p in params ] - saveargs = [_create_v0_savearg(p, self._context) 
for p in params] + saveargs = [ + _create_v0_savearg(p, self._context.array_options.saving) + for p in params + ] commit_futures = await self._handler_impl.serialize( values, paraminfos, saveargs diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py index 9de580c6d..5b6c48ef7 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/numpy_leaf_handler.py @@ -30,6 +30,8 @@ from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0 from orbax.checkpoint.experimental.v1._src.context import context as context_lib +import orbax.checkpoint.experimental.v1._src.context.options as options_lib +from orbax.checkpoint.experimental.v1._src.serialization import options_resolution from orbax.checkpoint.experimental.v1._src.serialization import registration from orbax.checkpoint.experimental.v1._src.serialization import types @@ -96,18 +98,18 @@ def _create_v0_saving_paraminfo( def _create_v0_savearg( param: NumpySerializationParam, - context: context_lib.Context, + array_saving_options: options_lib.ArrayOptions.Saving, ) -> type_handlers_v0.SaveArgs: - """Creates a V0 `SaveArgs` from V1 params and context for saving.""" - fn = context.pytree_options.saving.create_array_storage_options_fn - if fn: - storage_options = fn(param.keypath, param.value) - else: - storage_options = context.array_options.saving.storage_options + """Creates a V0 `SaveArgs` from V1 params and array saving options.""" + resolved_options = options_resolution.resolve_storage_options( + param.keypath, param.value, array_saving_options + ) return type_handlers_v0.SaveArgs( - dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None, - 
chunk_byte_size=storage_options.chunk_byte_size, - shard_axes=storage_options.shard_axes, + dtype=np.dtype(resolved_options.dtype) + if resolved_options.dtype is not None + else None, + chunk_byte_size=resolved_options.chunk_byte_size, + shard_axes=resolved_options.shard_axes, ) @@ -188,7 +190,10 @@ async def serialize( _create_v0_saving_paraminfo(p, self._context, serialization_context) for p in params ] - saveargs = [_create_v0_savearg(p, self._context) for p in params] + saveargs = [ + _create_v0_savearg(p, self._context.array_options.saving) + for p in params + ] commit_futures = await self._handler_impl.serialize( values, paraminfos, saveargs diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py new file mode 100644 index 000000000..96c12afd9 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution.py @@ -0,0 +1,77 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Resolves per-leaf array storage options for serialization."""
+
+from orbax.checkpoint.experimental.v1._src.context import options as options_lib
+from orbax.checkpoint.experimental.v1._src.tree import types as tree_types
+
+
+def resolve_storage_options(
+    keypath: tree_types.PyTreeKeyPath,
+    value: tree_types.LeafType,
+    array_saving_options: options_lib.ArrayOptions.Saving,
+) -> options_lib.ArrayOptions.Saving.StorageOptions:
+  """Resolves storage options using a global default and a per-leaf creator.
+
+  If `scoped_storage_options_creator` is set, it is called with `keypath` and
+  `value`. When it returns options, `dtype` and `chunk_byte_size` override the
+  global `storage_options` if they are not None, and `shard_axes` overrides it
+  if non-empty. All remaining fields, and every field when the creator is
+  unset or returns None, come from the global `storage_options`.
+
+  Args:
+    keypath: The PyTree keypath of the array being saved.
+    value: The PyTree leaf value (array) being saved.
+    array_saving_options: The Orbax array saving options to use for resolution.
+
+  Returns:
+    A new StorageOptions carrying the resolved field values.
+  """
+  global_opts = array_saving_options.storage_options
+  if global_opts is None:
+    global_opts = options_lib.ArrayOptions.Saving.StorageOptions()
+
+  fn = array_saving_options.scoped_storage_options_creator
+  individual_opts = None
+  if fn is not None:
+    individual_opts = fn(keypath, value)
+
+  if individual_opts is not None:
+    resolved_dtype = (
+        individual_opts.dtype
+        if individual_opts.dtype is not None
+        else global_opts.dtype
+    )
+    resolved_chunk_byte_size = (
+        individual_opts.chunk_byte_size
+        if individual_opts.chunk_byte_size is not None
+        else global_opts.chunk_byte_size
+    )
+    resolved_shard_axes = (
+        individual_opts.shard_axes
+        if individual_opts.shard_axes
+        else global_opts.shard_axes
+    )
+  else:
+    resolved_dtype = global_opts.dtype
+    resolved_chunk_byte_size = global_opts.chunk_byte_size
+    resolved_shard_axes = global_opts.shard_axes
+
+  return options_lib.ArrayOptions.Saving.StorageOptions(
+      dtype=resolved_dtype,
+      chunk_byte_size=resolved_chunk_byte_size,
+      shard_axes=resolved_shard_axes,
+  )
+
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py
new file mode 100644
index 000000000..18f459057
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/options_resolution_test.py
@@ -0,0 +1,119 @@
+# Copyright 2026 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for utility functions for serialization.""" + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import numpy as np +from orbax.checkpoint.experimental.v1._src.context import context as context_lib +from orbax.checkpoint.experimental.v1._src.context import options as options_lib +from orbax.checkpoint.experimental.v1._src.serialization import options_resolution + + +class OptionsResolutionTest(parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='callback_overriding_global', + callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( + dtype=np.int16 + ), + expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( + dtype=np.int16, + chunk_byte_size=16_000_000, + shard_axes=(0,), + ), + ), + dict( + testcase_name='callback_overriding_all', + callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( + dtype=np.float32, + chunk_byte_size=32_000_000, + shard_axes=(1,), + ), + expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( + dtype=np.float32, + chunk_byte_size=32_000_000, + shard_axes=(1,), + ), + ), + dict( + testcase_name='without_callback_falls_back_to_global', + callback=None, + expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( + dtype=np.int32, + chunk_byte_size=16_000_000, + shard_axes=(0,), + ), + ), + dict( + testcase_name='jnp_dtype_converter', + callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( + dtype=jnp.bfloat16, + ), + expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( + dtype=jnp.bfloat16, + chunk_byte_size=16_000_000, + shard_axes=(0,), + ), + ), + dict( + testcase_name='empty_shard_axes_falls_back_to_global', + callback=lambda k, v: options_lib.ArrayOptions.Saving.StorageOptions( + shard_axes=(), + ), + expected_storage_options=options_lib.ArrayOptions.Saving.StorageOptions( + dtype=np.int32, + 
chunk_byte_size=16_000_000, + shard_axes=(0,), + ), + ), + ) + def test_resolve_storage_options( + self, + callback, + expected_storage_options, + ): + # Global options + global_storage = options_lib.ArrayOptions.Saving.StorageOptions( + dtype=np.int32, + chunk_byte_size=16_000_000, + shard_axes=(0,), + ) + + context = context_lib.Context( + array_options=options_lib.ArrayOptions( + saving=options_lib.ArrayOptions.Saving( + storage_options=global_storage, + scoped_storage_options_creator=callback, + ) + ), + ) + + # Dummy param + keypath = (jax.tree_util.DictKey(key='foo'),) + value = np.ones((2, 2)) + + resolved_options = options_resolution.resolve_storage_options( + keypath, value, context.array_options.saving + ) + + self.assertEqual(resolved_options, expected_storage_options) + + +if __name__ == '__main__': + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py index 047b41390..7d0c0c323 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/scalar_leaf_handler.py @@ -27,6 +27,8 @@ from orbax.checkpoint._src.futures import future from orbax.checkpoint._src.serialization import type_handlers as type_handlers_v0 from orbax.checkpoint.experimental.v1._src.context import context as context_lib +import orbax.checkpoint.experimental.v1._src.context.options as options_lib +from orbax.checkpoint.experimental.v1._src.serialization import options_resolution from orbax.checkpoint.experimental.v1._src.serialization import registration from orbax.checkpoint.experimental.v1._src.serialization import types @@ -67,18 +69,18 @@ def _create_v0_saving_paraminfo( def _create_v0_savearg( param: ScalarSerializationParam, - context: context_lib.Context, + array_saving_options: options_lib.ArrayOptions.Saving, ) -> 
type_handlers_v0.SaveArgs: - """Creates a V0 SaveArgs from V1 params and context for saving.""" - fn = context.pytree_options.saving.create_array_storage_options_fn - if fn: - storage_options = fn(param.keypath, param.value) - else: - storage_options = context.array_options.saving.storage_options + """Creates a V0 SaveArgs from V1 params and array saving options.""" + resolved_options = options_resolution.resolve_storage_options( + param.keypath, param.value, array_saving_options + ) return type_handlers_v0.SaveArgs( - dtype=np.dtype(storage_options.dtype) if storage_options.dtype else None, - chunk_byte_size=storage_options.chunk_byte_size, - shard_axes=storage_options.shard_axes, + dtype=np.dtype(resolved_options.dtype) + if resolved_options.dtype is not None + else None, + chunk_byte_size=resolved_options.chunk_byte_size, + shard_axes=resolved_options.shard_axes, ) @@ -168,7 +170,10 @@ async def serialize( _create_v0_saving_paraminfo(p, self._context, serialization_context) for p in params ] - saveargs = [_create_v0_savearg(p, self._context) for p in params] + saveargs = [ + _create_v0_savearg(p, self._context.array_options.saving) + for p in params + ] commit_futures = await self._handler_impl.serialize( values, paraminfos, saveargs diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index bdcee5205..e4c109051 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -528,15 +528,15 @@ def test_casting(self, original_dtype, save_dtype, load_dtype): 'numpy_array': np.arange(len(jax.devices()), dtype=load_dtype), } - create_array_storage_options_fn = ( + scoped_storage_options_creator = ( lambda k, v: ocp.options.ArrayOptions.Saving.StorageOptions( dtype=save_dtype ) ) with ocp.Context( - 
pytree_options=ocp.options.PyTreeOptions( - saving=ocp.options.PyTreeOptions.Saving( - create_array_storage_options_fn=create_array_storage_options_fn + array_options=ocp.options.ArrayOptions( + saving=ocp.options.ArrayOptions.Saving( + scoped_storage_options_creator=scoped_storage_options_creator ) ) ): @@ -1166,7 +1166,7 @@ def test_subchunking(self): self.assertEqual(metadata[k].storage_metadata.chunk_shape, (2,)) with self.subTest('per_key_setting'): - def create_array_storage_options_fn(key, value): + def scoped_storage_options_creator(key, value): del value if 'a' in tree_utils.str_keypath(key): return ocp.options.ArrayOptions.Saving.StorageOptions( @@ -1176,9 +1176,9 @@ def create_array_storage_options_fn(key, value): chunk_byte_size=8, # force divide in 2 subchunks ) with ocp.Context( - pytree_options=ocp.options.PyTreeOptions( - saving=ocp.options.PyTreeOptions.Saving( - create_array_storage_options_fn=create_array_storage_options_fn + array_options=ocp.options.ArrayOptions( + saving=ocp.options.ArrayOptions.Saving( + scoped_storage_options_creator=scoped_storage_options_creator ) ), ): diff --git a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb index 42e731978..f387d901a 100644 --- a/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb +++ b/docs/guides/checkpoint/v1/checkpointing_pytrees.ipynb @@ -820,21 +820,21 @@ }, "outputs": [], "source": [ - "def create_array_storage_options_fn(keypath, value):\n", + "def scoped_storage_options_creator(keypath, value):\n", " del value\n", " last_key = keypath[-1]\n", + " # Override 'a' to int16\n", " if isinstance(last_key, jax.tree_util.GetAttrKey) and last_key.name == 'a':\n", " return ocp.options.ArrayOptions.Saving.StorageOptions(\n", " dtype=np.dtype(np.int16)\n", " )\n", - " else:\n", - " return ocp.options.ArrayOptions.Saving.StorageOptions()\n", - "\n", + " # Return None to use global default storage_options for other leaves\n", + " return 
None\n", "\n", "with ocp.Context(\n", - " pytree_options=ocp.options.PyTreeOptions(\n", - " saving=ocp.options.PyTreeOptions.Saving(\n", - " create_array_storage_options_fn=create_array_storage_options_fn,\n", + " array_options=ocp.options.ArrayOptions(\n", + " saving=ocp.options.ArrayOptions.Saving(\n", + " scoped_storage_options_creator=scoped_storage_options_creator,\n", " )\n", " )\n", "):\n", @@ -880,15 +880,15 @@ }, "outputs": [], "source": [ - "create_array_storage_options_fn = (\n", + "scoped_storage_options_creator = (\n", " lambda k, v: ocp.options.ArrayOptions.Saving.StorageOptions(\n", " dtype=np.dtype(np.int16)\n", " )\n", ")\n", "with ocp.Context(\n", - " pytree_options=ocp.options.PyTreeOptions(\n", - " saving=ocp.options.PyTreeOptions.Saving(\n", - " create_array_storage_options_fn=create_array_storage_options_fn\n", + " array_options=ocp.options.ArrayOptions(\n", + " saving=ocp.options.ArrayOptions.Saving(\n", + " scoped_storage_options_creator=scoped_storage_options_creator\n", " )\n", " )\n", "):\n",