diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md
index 9fa4e7192a..10746ddc3f 100644
--- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md
+++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md
@@ -9,41 +9,48 @@ bucket.
## Checkpoint loading priority
-The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed:
+The system follows a specific order when deciding which checkpoint to load at
+startup. The first valid condition met is the one executed:
1. **Resume Current Run**: If a checkpoint already exists for the current
`run_name`, the system loads the latest fully-saved checkpoint. This is the
default behavior to ensure minimal state loss when resuming after an
interruption.
2. **Load from Specific Path**: The system checks for a user-specified path.
- - If `load_parameters_path` is set, we load a parameter only checkpoint from that path..
- - If `load_full_state_path` is set, we load a full state checkpoint from that path.
- - **Note**: These two options are mutually exclusive and will cause an error if both are set.
-3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
+ - If `load_parameters_path` is set, we load a parameter-only checkpoint
+ from that path.
+ - If `load_full_state_path` is set, we load a full state checkpoint from
+ that path.
+ - **Note**: These two options are mutually exclusive and will cause an
+ error if both are set.
+3. **Initialize from Scratch**: We don't load a checkpoint and initialize state
+ instead.
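+For example, a run that warm-starts from a previous run's weights (the paths
+below are illustrative) sets exactly one of the two load options:
+
+```yaml
+# Load weights only; optimizer state and step count start fresh.
+load_parameters_path: "gs://my-bucket/my-previous-run/checkpoints/items/1000"
+# Must remain unset -- setting both paths raises an error.
+load_full_state_path: ""
+```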
### MaxText configuration
-| Flag | Description | Type | Default |
-| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :-------------- |
-| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
-| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` |
-| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
-| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
-| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
-| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
-| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
-| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
+Flag | Description | Type | Default
+:------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :------
+`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False`
+`async_checkpointing` | If `True`, checkpoint saving is asynchronous: the training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance and is enabled by default. | `boolean` | `True`
+`checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000`
+`enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False`
+`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` (Ignored if directory is prefixed with gs://) | `string` | `""`
+`checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""`
+`load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled)
+`load_full_state_path` | Specifies a path to a checkpoint directory to load a full state checkpoint, including optimizer state and step count.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled)
+`lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled)
+`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False`
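+For example, to stage old checkpoints in a trash location instead of deleting
+them in place (values below are illustrative), set one of the two flags:
+
+```yaml
+# Move checkpoints into a sibling subdirectory before deletion
+# (ignored when the checkpoint directory is prefixed with gs://).
+checkpoint_todelete_subdir: ".todelete"
+```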
## Storage and format configuration
-These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
-
-| Flag | Description | Type | Default |
-| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ |
-| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
-| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` |
-| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` |
-| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
-| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
-| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
-| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
+These settings control the underlying storage mechanism
+([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
+
+Flag | Description | Type | Default
+:----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------
+`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB)
+`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `False` for Pathways. | `boolean` | `True`
+`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True`
+`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96`
+`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False`
+`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"`
+`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None`
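+For example, a run targeting Pathways disables both storage-format flags
+described above:
+
+```yaml
+# Pathways is not compatible with the OCDBT or Zarr v3 storage formats.
+checkpoint_storage_use_ocdbt: False
+checkpoint_storage_use_zarr3: False
+```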
diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py
index cdfde92d50..17d4e79182 100644
--- a/src/maxtext/common/checkpointing.py
+++ b/src/maxtext/common/checkpointing.py
@@ -222,6 +222,8 @@ def create_orbax_checkpoint_manager(
colocated_python_checkpointing: bool = False,
enable_single_replica_ckpt_restoring: bool = False,
enable_autocheckpoint: bool = False,
+ todelete_subdir: Optional[str] = None,
+ todelete_full_path: Optional[str] = None,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
@@ -279,6 +281,8 @@ def create_orbax_checkpoint_manager(
save_decision_policy=save_decision_policy,
preservation_policy=preservation_policy,
async_options=async_options,
+ todelete_subdir=todelete_subdir,
+ todelete_full_path=todelete_full_path,
),
logger=orbax_logger,
)
diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml
index fb58aa79b4..cecfa7869d 100644
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -60,6 +60,12 @@ enable_continuous_checkpointing: False
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False
+# Subdirectory to move checkpoints to before deletion. For example: ".todelete" (Ignored if directory is prefixed with gs://)
+checkpoint_todelete_subdir: None
+
+# Full path to move checkpoints to before deletion. For example: "/_trash"
+checkpoint_todelete_full_path: None
+
force_unroll: False # during generate_param_only_checkpoint should we unroll the loop?
# checkpointing using orbax has two important parameters: array driver
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 454a9f23f5..6d9f459079 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -313,6 +313,11 @@ class Checkpointing(BaseModel):
enable_single_replica_ckpt_restoring: bool = Field(
False, description="One replica reads and broadcasts the checkpoint."
)
+ checkpoint_todelete_subdir: str | None = Field(
+ None,
+ description="Subdirectory to move checkpoints to before deletion. (Ignored if directory is prefixed with gs://)",
+ )
+ checkpoint_todelete_full_path: str | None = Field(None, description="Full path to move checkpoints to before deletion.")
force_unroll: bool = Field(
False,
description="During param-only checkpoint generation, whether to unroll the loop.",
diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py
index 54b2755801..2ed71a6e3f 100644
--- a/src/maxtext/utils/train_utils.py
+++ b/src/maxtext/utils/train_utils.py
@@ -83,6 +83,8 @@ def create_training_tools(config, model, mesh):
config.colocated_python_checkpointing,
config.enable_single_replica_ckpt_restoring,
config.enable_autocheckpoint,
+ config.checkpoint_todelete_subdir,
+ config.checkpoint_todelete_full_path,
)
return init_rng, checkpoint_manager, learning_rate_schedule, tx