diff --git a/docs/source/deep_dives/checkpointer.rst b/docs/source/deep_dives/checkpointer.rst
index 62e1c0f27f..1be305b80a 100644
--- a/docs/source/deep_dives/checkpointer.rst
+++ b/docs/source/deep_dives/checkpointer.rst
@@ -266,6 +266,61 @@ for testing or for loading quantized models for generation.
 
 |
 
+:class:`DistributedCheckpointer <torchtune.training.DistributedCheckpointer>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This checkpointer reads and writes checkpoints in a distributed format using PyTorch Distributed Checkpointing (DCP).
+The output format is DCP's default format, which shards the state dict across all ranks and saves a separate file per rank. This differs
+from the other checkpointer implementations, which consolidate the state dict on rank 0 and then save it as full tensors.
+The distributed checkpointer is used when asynchronous checkpointing is enabled during training, and relies on DCP's ``async_save`` API.
+Async distributed checkpointing is only used for intermediate checkpoints.
+
+When asynchronous checkpointing is enabled, intermediate checkpoints are saved using the DistributedCheckpointer
+without blocking the training process. This is particularly useful for large models, where saving checkpoints
+can take significant time.
+
+**Key Features:**
+
+- **Asynchronous Saving**: Allows training to continue while checkpoints are saved in the background
+- **Distributed-Aware**: Designed to work seamlessly in multi-GPU and multi-node training setups
+
+**Configuration Example:**
+
+To enable asynchronous checkpointing in your training config, set the ``enable_async_checkpointing``
+flag to ``True``. The DistributedCheckpointer will automatically be used when this flag is enabled.
+
+.. code-block:: yaml
+
+    checkpointer:
+      # checkpointer to use for final checkpoints
+      _component_: torchtune.training.FullModelHFCheckpointer
+
+    # Set to True to enable asynchronous distributed checkpointing for intermediate checkpoints
+    enable_async_checkpointing: True
+
+**Resuming Training with DistributedCheckpointer:**
+
+If your training was interrupted and you had async checkpointing enabled, you can resume from the latest
+distributed checkpoint by setting both ``resume_from_checkpoint`` and ``enable_async_checkpointing`` to ``True``:
+
+.. code-block:: yaml
+
+    # Set to True to resume from checkpoint
+    resume_from_checkpoint: True
+
+    # Set to True to enable asynchronous checkpointing
+    enable_async_checkpointing: True
+
+The DistributedCheckpointer will automatically locate and load the latest intermediate checkpoint from the
+output directory.
+
+.. note::
+
+    The final checkpoint at the end of training is always saved synchronously to ensure all
+    data is properly persisted in safetensors or torch.save format before the training job completes.
+
+|
+
 Checkpoint Output
 ---------------------------------
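
The patch above describes asynchronous checkpointing in terms of DCP's ``async_save`` API. As a supplementary illustration (not part of the patch, and not torchtune's actual implementation), the underlying save/resume pattern looks roughly like the sketch below; the single-rank gloo process group, the toy model, and the ``/tmp`` checkpoint path are placeholder assumptions.

.. code-block:: python

    # Minimal sketch of the DCP async-save pattern the DistributedCheckpointer builds on.
    # Assumptions: single-rank gloo process group, toy model, placeholder checkpoint path.
    import torch
    import torch.distributed as dist
    import torch.distributed.checkpoint as dcp

    # Single-process setup just to make the sketch runnable; a real training job
    # initializes the process group across all ranks.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

    model = torch.nn.Linear(16, 16)
    state_dict = {"model": model.state_dict()}

    # async_save starts persisting the checkpoint in the background and returns a
    # Future, so the training loop can keep running while the files are written.
    save_future = dcp.async_save(state_dict, checkpoint_id="/tmp/ckpt_step_100")

    # ... training continues here ...

    # Block before the next save (or at shutdown) so the data is fully persisted.
    save_future.result()

    # Resuming: dcp.load reads the per-rank files back into an existing state dict.
    dcp.load(state_dict, checkpoint_id="/tmp/ckpt_step_100")
    model.load_state_dict(state_dict["model"])

    dist.destroy_process_group()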