diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md index 9fa4e7192a..10746ddc3f 100644 --- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md +++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md @@ -9,41 +9,48 @@ bucket. ## Checkpoint loading priority -The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed: +The system follows a specific order when deciding which checkpoint to load at +startup. The first valid condition met is the one executed: 1. **Resume Current Run**: If a checkpoint already exists for the current `run_name`, the system loads the latest fully-saved checkpoint. This is the default behavior to ensure minimal state loss when resuming after an interruption. 2. **Load from Specific Path**: The system checks for a user-specified path. - - If `load_parameters_path` is set, we load a parameter only checkpoint from that path.. - - If `load_full_state_path` is set, we load a full state checkpoint from that path. - - **Note**: These two options are mutually exclusive and will cause an error if both are set. -3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead. + - If `load_parameters_path` is set, we load a parameter only checkpoint + from that path.. + - If `load_full_state_path` is set, we load a full state checkpoint from + that path. + - **Note**: These two options are mutually exclusive and will cause an + error if both are set. +3. **Initialize from Scratch**: We don't load a checkpoint and initialize state + instead. ### MaxText configuration -| Flag | Description | Type | Default | -| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :-------------- | -| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | -| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | -| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` | -| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` | -| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) | -| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) | -| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) | -| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` | +Flag | Description | Type | Default +:------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :------ +`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` +`async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` +`checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` +`enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` +`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` (Ignored if directory is prefixed with gs://) | `string` | ``` ""``checkpoint_todelete_full_path ``` | Full path to move checkpoints to before deletion. | `string` | `""` +`load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) +`load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) +`lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) +`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` ## Storage and format configuration -These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility. - -| Flag | Description | Type | Default | -| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ | -| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) | -| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` | -| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` | -| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` | -| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` | -| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` | -| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` | +These settings control the underlying storage mechanism +([Orbax](https://orbax.readthedocs.io)) for performance and compatibility. + +Flag | Description | Type | Default +:----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------ +`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) +`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` +`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` +`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` +`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` +`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` +`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index cdfde92d50..17d4e79182 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -222,6 +222,8 @@ def create_orbax_checkpoint_manager( colocated_python_checkpointing: bool = False, enable_single_replica_ckpt_restoring: bool = False, enable_autocheckpoint: bool = False, + todelete_subdir: Optional[str] = None, + todelete_full_path: Optional[str] = None, ): """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled.""" if not enable_checkpointing: @@ -279,6 +281,8 @@ def create_orbax_checkpoint_manager( save_decision_policy=save_decision_policy, preservation_policy=preservation_policy, async_options=async_options, + todelete_subdir=todelete_subdir, + todelete_full_path=todelete_full_path, ), logger=orbax_logger, ) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index fb58aa79b4..cecfa7869d 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -60,6 +60,12 @@ enable_continuous_checkpointing: False # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False +# Subdirectory to move checkpoints to before deletion. For example: ".todelete" (Ignored if directory is prefixed with gs://) +checkpoint_todelete_subdir: None + +# Full path to move checkpoints to before deletion. For example: "/_trash" +checkpoint_todelete_full_path: None + force_unroll: False # during generate_param_only_checkpoint should we unroll the loop? # checkpointing using orbax has two important parameters: array driver diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 454a9f23f5..6d9f459079 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -313,6 +313,11 @@ class Checkpointing(BaseModel): enable_single_replica_ckpt_restoring: bool = Field( False, description="One replica reads and broadcasts the checkpoint." ) + checkpoint_todelete_subdir: str | None = Field( + None, + description="Subdirectory to move checkpoints to before deletion. (Ignored if directory is prefixed with gs://)", + ) + checkpoint_todelete_full_path: str | None = Field(None, description="Full path to move checkpoints to before deletion.") force_unroll: bool = Field( False, description="During param-only checkpoint generation, whether to unroll the loop.", diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 54b2755801..2ed71a6e3f 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -83,6 +83,8 @@ def create_training_tools(config, model, mesh): config.colocated_python_checkpointing, config.enable_single_replica_ckpt_restoring, config.enable_autocheckpoint, + config.checkpoint_todelete_subdir, + config.checkpoint_todelete_full_path, ) return init_rng, checkpoint_manager, learning_rate_schedule, tx