From 3a1b06229162f1e99a4e1af722f7e2bc85235404 Mon Sep 17 00:00:00 2001 From: Adam Wonak Date: Mon, 23 Mar 2026 23:39:06 +0000 Subject: [PATCH 1/5] Add checkpoint deletion options to configuration and checkpoint manager --- src/maxtext/common/checkpointing.py | 4 ++++ src/maxtext/configs/base.yml | 6 ++++++ src/maxtext/configs/types.py | 6 ++++++ 3 files changed, 16 insertions(+) diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 220ff6f16d..7b522acc7e 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -221,6 +221,8 @@ def create_orbax_checkpoint_manager( enable_single_controller: bool = False, colocated_python_checkpointing: bool = False, enable_single_replica_ckpt_restoring: bool = False, + todelete_subdir: Optional[str] = None, + todelete_full_path: Optional[str] = None, ): """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.""" if not enable_checkpointing: @@ -268,6 +270,8 @@ def create_orbax_checkpoint_manager( save_decision_policy=save_decision_policy, preservation_policy=preservation_policy, async_options=async_options, + todelete_subdir=todelete_subdir, + todelete_full_path=todelete_full_path, ), logger=orbax_logger, ) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 85db2bcce4..c7750c5b98 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -60,6 +60,12 @@ enable_continuous_checkpointing: False # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False +# Subdirectory to move checkpoints to before deletion. For example: ".todelete" +checkpoint_todelete_subdir: "" + +# Full path to move checkpoints to before deletion. +checkpoint_todelete_full_path: "" + force_unroll: False # during generate_param_only_checkpoint should we unroll the loop? # checkpointing using orbax has two important parameters: array driver diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 9ee1a7bb59..08f7469806 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -309,6 +309,12 @@ class Checkpointing(BaseModel): enable_single_replica_ckpt_restoring: bool = Field( False, description="One replica reads and broadcasts the checkpoint." ) + checkpoint_todelete_subdir: str = Field( + "", description="Subdirectory to move checkpoints to before deletion." + ) + checkpoint_todelete_full_path: str = Field( + "", description="Full path to move checkpoints to before deletion." + ) force_unroll: bool = Field( False, description="During param-only checkpoint generation, whether to unroll the loop.", From 39acb6ace66f4c19a2d20e09ec727b22308506d8 Mon Sep 17 00:00:00 2001 From: Adam Wonak Date: Mon, 23 Mar 2026 23:47:57 +0000 Subject: [PATCH 2/5] update documentation to include newly added orbax flags --- docs/guides/checkpointing_solutions/gcs_checkpointing.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md index 9fa4e7192a..56004c2466 100644 --- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md +++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md @@ -29,6 +29,8 @@ The system follows a specific order when deciding which checkpoint to load at st | `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | | `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` | | `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` | +| `checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` | `string` | `""` | +| `checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""` | | `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) | | `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) | | `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) | From 9a662a5566bf544a454deb0b854b60afe361585c Mon Sep 17 00:00:00 2001 From: Adam Wonak Date: Mon, 30 Mar 2026 17:58:07 +0000 Subject: [PATCH 3/5] Default to None to avoid setting both flags simultaneously. --- src/maxtext/configs/base.yml | 6 +++--- src/maxtext/configs/types.py | 8 ++------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index c7750c5b98..a692824aaa 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -61,10 +61,10 @@ enable_continuous_checkpointing: False enable_single_replica_ckpt_restoring: False # Subdirectory to move checkpoints to before deletion. For example: ".todelete" -checkpoint_todelete_subdir: "" +checkpoint_todelete_subdir: None -# Full path to move checkpoints to before deletion. -checkpoint_todelete_full_path: "" +# Full path to move checkpoints to before deletion. For example: "/_trash" +checkpoint_todelete_full_path: None force_unroll: False # during generate_param_only_checkpoint should we unroll the loop? diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 08f7469806..420a74c547 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -309,12 +309,8 @@ class Checkpointing(BaseModel): enable_single_replica_ckpt_restoring: bool = Field( False, description="One replica reads and broadcasts the checkpoint." ) - checkpoint_todelete_subdir: str = Field( - "", description="Subdirectory to move checkpoints to before deletion." - ) - checkpoint_todelete_full_path: str = Field( - "", description="Full path to move checkpoints to before deletion." - ) + checkpoint_todelete_subdir: str | None = Field(None, description="Subdirectory to move checkpoints to before deletion. (Ignored if directory is prefixed with gs://)") + checkpoint_todelete_full_path: str | None = Field(None, description="Full path to move checkpoints to before deletion.") force_unroll: bool = Field( False, description="During param-only checkpoint generation, whether to unroll the loop.", From 35ff419ee6121905af4a0acbf10769f355b353f7 Mon Sep 17 00:00:00 2001 From: Adam Wonak Date: Mon, 30 Mar 2026 18:13:33 +0000 Subject: [PATCH 4/5] formatting --- .../gcs_checkpointing.md | 74 ++++++++++--------- src/maxtext/configs/base.yml | 2 +- src/maxtext/configs/types.py | 5 +- 3 files changed, 45 insertions(+), 36 deletions(-) diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md index 56004c2466..2ea8e0ce20 100644 --- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md +++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md @@ -9,43 +9,49 @@ bucket. ## Checkpoint loading priority -The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed: - -1. **Resume Current Run**: If a checkpoint already exists for the current - `run_name`, the system loads the latest fully-saved checkpoint. This is the - default behavior to ensure minimal state loss when resuming after an - interruption. -2. **Load from Specific Path**: The system checks for a user-specified path. - - If `load_parameters_path` is set, we load a parameter only checkpoint from that path.. - - If `load_full_state_path` is set, we load a full state checkpoint from that path. - - **Note**: These two options are mutually exclusive and will cause an error if both are set. -3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead. +The system follows a specific order when deciding which checkpoint to load at +startup. The first valid condition met is the one executed: + +1. **Resume Current Run**: If a checkpoint already exists for the current + `run_name`, the system loads the latest fully-saved checkpoint. This is the + default behavior to ensure minimal state loss when resuming after an + interruption. +2. **Load from Specific Path**: The system checks for a user-specified path. + - If `load_parameters_path` is set, we load a parameter only checkpoint + from that path.. + - If `load_full_state_path` is set, we load a full state checkpoint from + that path. + - **Note**: These two options are mutually exclusive and will cause an + error if both are set. +3. **Initialize from Scratch**: We don't load a checkpoint and initialize state + instead. ### MaxText configuration -| Flag | Description | Type | Default | -| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :-------------- | -| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | -| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | -| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` | -| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` | -| `checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` | `string` | `""` | -| `checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""` | -| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) | -| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) | -| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) | -| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` | +Flag | Description | Type | Default +:------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :------ +`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` +`async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` +`checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` +`enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` +`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"`. (Ignored if directory is prefixed with gs://) | `string` | `""` +`checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""` +`load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) +`load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) +`lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) +`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` ## Storage and format configuration -These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility. - -| Flag | Description | Type | Default | -| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ | -| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) | -| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` | -| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` | -| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` | -| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` | -| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` | -| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` | +These settings control the underlying storage mechanism +([Orbax](https://orbax.readthedocs.io)) for performance and compatibility. + +Flag | Description | Type | Default +:----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------ +`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) +`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` +`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` +`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` +`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` +`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` +`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index a692824aaa..900955390f 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -60,7 +60,7 @@ enable_continuous_checkpointing: False # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False -# Subdirectory to move checkpoints to before deletion. For example: ".todelete" +# Subdirectory to move checkpoints to before deletion. For example: ".todelete" (Ignored if directory is prefixed with gs://) checkpoint_todelete_subdir: None # Full path to move checkpoints to before deletion. For example: "/_trash" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 420a74c547..47025355ee 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -309,7 +309,10 @@ class Checkpointing(BaseModel): enable_single_replica_ckpt_restoring: bool = Field( False, description="One replica reads and broadcasts the checkpoint." ) - checkpoint_todelete_subdir: str | None = Field(None, description="Subdirectory to move checkpoints to before deletion. (Ignored if directory is prefixed with gs://)") + checkpoint_todelete_subdir: str | None = Field( + None, + description="Subdirectory to move checkpoints to before deletion. (Ignored if directory is prefixed with gs://)", + ) checkpoint_todelete_full_path: str | None = Field(None, description="Full path to move checkpoints to before deletion.") force_unroll: bool = Field( False, From 112c3e97155962f67378d0ace4345abd1a504427 Mon Sep 17 00:00:00 2001 From: Adam Wonak Date: Mon, 6 Apr 2026 21:05:54 +0000 Subject: [PATCH 5/5] Ensure flag gets used. update docs. --- .../gcs_checkpointing.md | 61 +++++++++---------- src/maxtext/utils/train_utils.py | 2 + 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md index 2ea8e0ce20..10746ddc3f 100644 --- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md +++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md @@ -12,46 +12,45 @@ bucket. The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed: -1. **Resume Current Run**: If a checkpoint already exists for the current - `run_name`, the system loads the latest fully-saved checkpoint. This is the - default behavior to ensure minimal state loss when resuming after an - interruption. -2. **Load from Specific Path**: The system checks for a user-specified path. - - If `load_parameters_path` is set, we load a parameter only checkpoint - from that path.. - - If `load_full_state_path` is set, we load a full state checkpoint from - that path. - - **Note**: These two options are mutually exclusive and will cause an - error if both are set. -3. **Initialize from Scratch**: We don't load a checkpoint and initialize state - instead. +1. **Resume Current Run**: If a checkpoint already exists for the current + `run_name`, the system loads the latest fully-saved checkpoint. This is the + default behavior to ensure minimal state loss when resuming after an + interruption. +2. **Load from Specific Path**: The system checks for a user-specified path. + - If `load_parameters_path` is set, we load a parameter only checkpoint + from that path.. + - If `load_full_state_path` is set, we load a full state checkpoint from + that path. + - **Note**: These two options are mutually exclusive and will cause an + error if both are set. +3. **Initialize from Scratch**: We don't load a checkpoint and initialize state + instead. ### MaxText configuration -Flag | Description | Type | Default +Flag | Description | Type | Default :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :------ -`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` -`async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` -`checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` +`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` +`async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` +`checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` -`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"`. (Ignored if directory is prefixed with gs://) | `string` | `""` -`checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""` -`load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) -`load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) -`lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) -`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` +`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` (Ignored if directory is prefixed with gs://) | `string` | ``` ""``checkpoint_todelete_full_path ``` | Full path to move checkpoints to before deletion. | `string` | `""` +`load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) +`load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) +`lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) +`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` ## Storage and format configuration These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility. -Flag | Description | Type | Default +Flag | Description | Type | Default :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------ -`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) -`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` -`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` -`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` -`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` -`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` -`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` +`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) +`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` +`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` +`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` +`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` +`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` +`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 00eb408ad3..3dd5a14638 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -82,6 +82,8 @@ def create_training_tools(config, model, mesh): config.enable_single_controller, config.colocated_python_checkpointing, config.enable_single_replica_ckpt_restoring, + config.checkpoint_todelete_subdir, + config.checkpoint_todelete_full_path, ) return init_rng, checkpoint_manager, learning_rate_schedule, tx