59 changes: 33 additions & 26 deletions docs/guides/checkpointing_solutions/gcs_checkpointing.md
@@ -9,41 +9,48 @@ bucket.

## Checkpoint loading priority

The system follows a specific order when deciding which checkpoint to load at
startup. The first valid condition met is the one executed:

1. **Resume Current Run**: If a checkpoint already exists for the current
`run_name`, the system loads the latest fully-saved checkpoint. This is the
default behavior to ensure minimal state loss when resuming after an
interruption.
2. **Load from Specific Path**: The system checks for a user-specified path.
   - If `load_parameters_path` is set, we load a parameter-only checkpoint
     from that path.
   - If `load_full_state_path` is set, we load a full-state checkpoint from
     that path.
   - **Note**: These two options are mutually exclusive; setting both raises
     an error.
3. **Initialize from Scratch**: If neither path is set, we skip loading and
   initialize the training state from scratch.
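The decision order above can be sketched in plain Python (an illustration only, not the MaxText implementation; `resolve_checkpoint_source` and its return labels are hypothetical):

```python
def resolve_checkpoint_source(
    run_has_checkpoint: bool,
    load_parameters_path: str,
    load_full_state_path: str,
) -> str:
  """Mirrors the loading priority: resume > explicit path > from scratch."""
  # The two explicit-path options are mutually exclusive.
  if load_parameters_path and load_full_state_path:
    raise ValueError(
        "load_parameters_path and load_full_state_path are mutually exclusive")
  # 1. Resume the current run if it already has a checkpoint.
  if run_has_checkpoint:
    return "resume_latest"
  # 2. Otherwise honor a user-specified path.
  if load_parameters_path:
    return "parameters_only"
  if load_full_state_path:
    return "full_state"
  # 3. Nothing to load: initialize from scratch.
  return "init_from_scratch"
```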

### MaxText configuration

Flag | Description | Type | Default
:--- | :--- | :--- | :---
`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False`
`async_checkpointing` | If `True`, checkpoint saving is asynchronous: the training step is blocked only for the minimal time needed to capture the model's state, and the actual write to storage happens in a background thread. Highly recommended for performance. | `boolean` | `True`
`checkpoint_period` | The interval, in training steps, between checkpoint saves. | `integer` | `10000`
`enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that use a Distributed Data Parallel (DDP) strategy. | `boolean` | `False`
`checkpoint_todelete_subdir` | Subdirectory to move checkpoints into before deletion. For example: `".todelete"`. (Ignored if the directory is prefixed with `gs://`.) | `string` | `""`
`checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""`
`load_parameters_path` | Path to a checkpoint directory from which to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled)
`load_full_state_path` | Path to a checkpoint directory from which to load a full checkpoint, including optimizer state and step count.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled)
`lora_input_adapters_path` | Parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled)
`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False`
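The two `checkpoint_todelete_*` flags defer deletion by renaming stale checkpoints out of the way rather than removing them in place. A minimal sketch of that idea in plain Python (an illustration of the behavior, not the Orbax implementation; `stage_for_deletion` is a hypothetical helper):

```python
import os


def stage_for_deletion(checkpoint_dir: str, step_name: str,
                       todelete_subdir: str) -> str:
  """Renames a stale checkpoint into a trash subdirectory instead of deleting it."""
  trash_dir = os.path.join(checkpoint_dir, todelete_subdir)
  os.makedirs(trash_dir, exist_ok=True)
  src = os.path.join(checkpoint_dir, step_name)
  dst = os.path.join(trash_dir, step_name)
  # Atomic rename on the same filesystem; a separate job can purge trash_dir later.
  os.replace(src, dst)
  return dst
```

Deferring the actual delete keeps the (potentially slow) cleanup off the training critical path.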

## Storage and format configuration


| Flag | Description | Type | Default |
| :----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ |
| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` |
| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` |
| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
These settings control the underlying storage mechanism
([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.

Flag | Description | Type | Default
:----------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------
`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB)
`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True`
`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True`
`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96`
`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False`
`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"`
`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None`
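As a back-of-the-envelope for `checkpoint_storage_target_data_file_size_bytes`: a large array is split into roughly `ceil(array_bytes / target)` physical files (simple arithmetic, not the actual Orbax chunking algorithm):

```python
import math


def approx_file_count(array_bytes: int,
                      target_bytes: int = 2_147_483_648) -> int:
  """Rough number of physical files an array splits into under the size target."""
  return max(1, math.ceil(array_bytes / target_bytes))
```

A smaller target yields more, smaller objects in the bucket, which is what lets distributed readers pull them in parallel.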
4 changes: 4 additions & 0 deletions src/maxtext/common/checkpointing.py
@@ -222,6 +222,8 @@ def create_orbax_checkpoint_manager(
colocated_python_checkpointing: bool = False,
enable_single_replica_ckpt_restoring: bool = False,
enable_autocheckpoint: bool = False,
todelete_subdir: Optional[str] = None,
todelete_full_path: Optional[str] = None,
):
"""Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
if not enable_checkpointing:
@@ -279,6 +281,8 @@ def create_orbax_checkpoint_manager(
save_decision_policy=save_decision_policy,
preservation_policy=preservation_policy,
async_options=async_options,
todelete_subdir=todelete_subdir,
todelete_full_path=todelete_full_path,
),
logger=orbax_logger,
)
6 changes: 6 additions & 0 deletions src/maxtext/configs/base.yml
@@ -60,6 +60,12 @@ enable_continuous_checkpointing: False
# enables one replica to read the ckpt then broadcast to the rest
enable_single_replica_ckpt_restoring: False

# Subdirectory to move checkpoints to before deletion. For example: ".todelete" (Ignored if directory is prefixed with gs://)
checkpoint_todelete_subdir: None

# Full path to move checkpoints to before deletion. For example: "/_trash"
checkpoint_todelete_full_path: None

force_unroll: False # during generate_param_only_checkpoint should we unroll the loop?

# checkpointing using orbax has two important parameters: array driver
5 changes: 5 additions & 0 deletions src/maxtext/configs/types.py
@@ -313,6 +313,11 @@ class Checkpointing(BaseModel):
enable_single_replica_ckpt_restoring: bool = Field(
False, description="One replica reads and broadcasts the checkpoint."
)
checkpoint_todelete_subdir: str | None = Field(
None,
description="Subdirectory to move checkpoints to before deletion. (Ignored if directory is prefixed with gs://)",
)
checkpoint_todelete_full_path: str | None = Field(None, description="Full path to move checkpoints to before deletion.")
force_unroll: bool = Field(
False,
description="During param-only checkpoint generation, whether to unroll the loop.",
2 changes: 2 additions & 0 deletions src/maxtext/utils/train_utils.py
@@ -83,6 +83,8 @@ def create_training_tools(config, model, mesh):
config.colocated_python_checkpointing,
config.enable_single_replica_ckpt_restoring,
config.enable_autocheckpoint,
config.checkpoint_todelete_subdir,
config.checkpoint_todelete_full_path,
)

return init_rng, checkpoint_manager, learning_rate_schedule, tx