Add DCP async checkpointing info to docs #2837

Merged
merged 2 commits into from Jun 24, 2025
55 changes: 55 additions & 0 deletions docs/source/deep_dives/checkpointer.rst
@@ -266,6 +266,61 @@ for testing or for loading quantized models for generation.

|

:class:`DistributedCheckpointer <torchtune.training.DistributedCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a distributed format using PyTorch Distributed Checkpointing (DCP).
The output format is DCP's default format, which saves the state dict across all ranks as separate files, one per rank. This differs
from the other checkpointer implementations, which consolidate the state dict on rank 0 and then save it as full tensors.
The distributed checkpointer is used when asynchronous checkpointing is enabled during training, and it relies on DCP's ``async_save`` API.
Async distributed checkpointing is only used for intermediate checkpoints.

When asynchronous checkpointing is enabled, intermediate checkpoints are saved using the DistributedCheckpointer
without blocking the training process. This is particularly useful for large models where saving checkpoints
can take significant time.
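
Under the hood this relies on DCP's ``async_save`` API. The snippet below is a minimal sketch of that mechanism on its own, not torchtune's exact code; the checkpoint path, model, and optimizer are placeholders, and the script is assumed to be launched with ``torchrun`` so a process group can be initialized.

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    dist.init_process_group(backend="gloo")  # torchtune uses NCCL on GPUs; gloo keeps this sketch CPU-only

    # Placeholder model and optimizer for illustration
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}

    # Each rank writes its own shard in the background; a Future is returned
    # immediately, so the training loop is not blocked.
    future = dcp.async_save(state_dict, checkpoint_id="output/intermediate_checkpoint")

    # ... training continues here ...

    # Wait for the previous async save to finish (e.g. before the next save or at shutdown)
    future.result()
    dist.destroy_process_group()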

**Key Features:**

- **Asynchronous Saving**: Allows training to continue while checkpoints are being saved in the background
- **Distributed-Aware**: Designed to work seamlessly in multi-GPU and multi-node training setups

**Configuration Example:**

To enable asynchronous checkpointing in your training config, you need to set the ``enable_async_checkpointing``
flag to ``True``. The DistributedCheckpointer will be automatically used when this flag is enabled.

.. code-block:: yaml

checkpointer:
# checkpointer to use for final checkpoints
_component_: torchtune.training.FullModelHFCheckpointer

# Set to True to enable asynchronous distributed checkpointing for intermediate checkpoints
enable_async_checkpointing: True
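
The flag can also be overridden from the command line when launching a recipe. The recipe and config names below are placeholders; substitute the ones you are actually using:

.. code-block:: bash

    tune run --nproc_per_node 2 full_finetune_distributed \
      --config llama3_1/8B_full \
      enable_async_checkpointing=True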

**Resuming Training with DistributedCheckpointer:**

If your training was interrupted and you had async checkpointing enabled, you can resume from the latest
distributed checkpoint by setting both ``resume_from_checkpoint`` and ``enable_async_checkpointing`` to ``True``:

.. code-block:: yaml

# Set to True to resume from checkpoint
resume_from_checkpoint: True

# Set to True to enable asynchronous checkpointing
enable_async_checkpointing: True

The DistributedCheckpointer will automatically locate and load the latest intermediate checkpoint from the
output directory.
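
For reference, the snippet below is a minimal sketch of how a DCP-format checkpoint is read back, not torchtune's actual resume logic; the checkpoint path and model are placeholders:

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    dist.init_process_group(backend="gloo")

    # Placeholder model; in practice this is the model being fine-tuned
    model = torch.nn.Linear(16, 16)
    state_dict = {"model": model.state_dict()}

    # dcp.load populates the state dict in place from the sharded files
    # written by async_save, then the model is restored from it.
    dcp.load(state_dict, checkpoint_id="output/intermediate_checkpoint")
    model.load_state_dict(state_dict["model"])

    dist.destroy_process_group()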

.. note::

The final checkpoint at the end of training is always saved synchronously to ensure all
data is properly persisted in safetensors or torch.save format before the training job completes.

|

Checkpoint Output
---------------------------------
