Commit 0b1751f

Add DCP async checkpointing info to docs (#2837)
1 parent 9983bbc commit 0b1751f

File tree

1 file changed: +55 -0

docs/source/deep_dives/checkpointer.rst

Lines changed: 55 additions & 0 deletions
@@ -266,6 +266,61 @@ for testing or for loading quantized models for generation.

|

:class:`DistributedCheckpointer <torchtune.training.DistributedCheckpointer>`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This checkpointer reads and writes checkpoints in a distributed format using PyTorch Distributed Checkpointing (DCP).
The output format is DCP's default format, which saves the state dict across all ranks as separate files, one per rank. This differs
from the other checkpointer implementations, which consolidate the state dict on rank 0 and then save it as full tensors.
The distributed checkpointer is used when asynchronous checkpointing is enabled during training, and it relies on DCP's ``async_save`` API.
Async distributed checkpointing is only used for intermediate checkpoints.

When asynchronous checkpointing is enabled, intermediate checkpoints are saved using the DistributedCheckpointer
without blocking the training process. This is particularly useful for large models where saving checkpoints
can take significant time.
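
Conceptually, what happens during an async save can be sketched with the underlying
``torch.distributed.checkpoint`` API. This is a simplified illustration, not torchtune's internal
code; the ``async_checkpoint`` helper and its arguments are hypothetical, and ``model`` and
``optimizer`` stand in for the real distributed training objects.

.. code-block:: python

    import torch.distributed.checkpoint as dcp


    def async_checkpoint(model, optimizer, step, output_dir):
        # Each rank contributes its local shard of the state dict; DCP
        # writes one file per rank in its default format.
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        # async_save returns a future and lets the training loop keep
        # running while the checkpoint is written in the background.
        return dcp.async_save(
            state_dict, checkpoint_id=f"{output_dir}/step_{step}"
        )

A typical pattern is to hold on to the returned future and call ``.result()`` on it before
starting the next save, so that at most one checkpoint write is in flight at a time.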

**Key Features:**

- **Asynchronous Saving**: Allows training to continue while checkpoints are being saved in the background
- **Distributed-Aware**: Designed to work seamlessly in multi-GPU and multi-node training setups

**Configuration Example:**

To enable asynchronous checkpointing in your training config, set the ``enable_async_checkpointing``
flag to ``True``. The DistributedCheckpointer will then be used automatically for intermediate checkpoints.

.. code-block:: yaml

    checkpointer:
      # checkpointer to use for final checkpoints
      _component_: torchtune.training.FullModelHFCheckpointer

    # Set to True to enable asynchronous distributed checkpointing for intermediate checkpoints
    enable_async_checkpointing: True

**Resuming Training with DistributedCheckpointer:**

If your training was interrupted and you had async checkpointing enabled, you can resume from the latest
distributed checkpoint by setting both ``resume_from_checkpoint`` and ``enable_async_checkpointing`` to ``True``:

.. code-block:: yaml

    # Set to True to resume from checkpoint
    resume_from_checkpoint: True

    # Set to True to enable asynchronous checkpointing
    enable_async_checkpointing: True

The DistributedCheckpointer will automatically locate and load the latest intermediate checkpoint from the
output directory.
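
At the DCP level, resuming amounts to loading the per-rank files back into an already
materialized state dict, in place. A minimal sketch, with a hypothetical ``load_checkpoint``
helper and placeholder path; torchtune performs the equivalent for you when the two flags
above are set:

.. code-block:: python

    import torch.distributed.checkpoint as dcp


    def load_checkpoint(model, optimizer, checkpoint_dir):
        # dcp.load restores tensors in place into the provided state dict,
        # with each rank reading only the shards it owns.
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        dcp.load(state_dict, checkpoint_id=checkpoint_dir)
        model.load_state_dict(state_dict["model"])
        optimizer.load_state_dict(state_dict["optimizer"])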

.. note::

    The final checkpoint at the end of training is always saved synchronously to ensure all
    data is properly persisted in safetensors or ``torch.save`` format before the training job completes.
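
If you ever need a single-file view of an intermediate DCP checkpoint (for example, to inspect
it on one machine), PyTorch ships a conversion utility. A minimal sketch with placeholder paths:

.. code-block:: python

    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    # Consolidate the per-rank DCP files into a single torch.save file.
    dcp_to_torch_save("output_dir/step_100", "output_dir/step_100.pt")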

|

Checkpoint Output
---------------------------------
