From df4b4d57fcf77ba5a8f1e1ae12703432404d53ef Mon Sep 17 00:00:00 2001
From: Ankita George
Date: Wed, 18 Jun 2025 13:44:19 -0700
Subject: [PATCH 1/2] add dcp checkpointing to docs

---
 docs/source/deep_dives/checkpointer.rst | 56 +++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/docs/source/deep_dives/checkpointer.rst b/docs/source/deep_dives/checkpointer.rst
index 62e1c0f27f..692c6962b2 100644
--- a/docs/source/deep_dives/checkpointer.rst
+++ b/docs/source/deep_dives/checkpointer.rst
@@ -266,6 +266,62 @@ for testing or for loading quantized models for generation.
 
 |
 
+:class:`DistributedCheckpointer <torchtune.training.DistributedCheckpointer>`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This checkpointer reads and writes checkpoints in a distributed format using PyTorch Distributed Checkpointing (DCP).
+The output format is DCP's default format, which saves the state dict across all ranks as separate files, one per rank.
+This differs from the other checkpointer implementations, which consolidate the state dict onto rank 0 and then save it
+as full tensors. The DistributedCheckpointer is used when asynchronous checkpointing is enabled during training and
+relies on DCP's ``async_save`` API (a minimal sketch of this pattern is shown at the end of this section).
+Async distributed checkpointing is only used for intermediate checkpoints.
+
+When asynchronous checkpointing is enabled, intermediate checkpoints are saved using the DistributedCheckpointer
+without blocking the training process. This is particularly useful for large models, where saving checkpoints
+can take significant time.
+
+**Key Features:**
+
+- **Asynchronous Saving**: Allows training to continue while checkpoints are being saved in the background
+- **Distributed-Aware**: Designed to work seamlessly in multi-GPU and multi-node training setups
+
+**Configuration Example:**
+
+To enable asynchronous checkpointing in your training config, you need to set the ``enable_async_checkpointing``
+flag to ``True``. The DistributedCheckpointer will be automatically used when this flag is enabled.
+
+.. code-block:: yaml
+
+    checkpointer:
+      # checkpointer to use for final checkpoints
+      _component_: torchtune.training.FullModelMetaCheckpointer
+
+    # Set to True to enable asynchronous distributed checkpointing for intermediate checkpoints
+    enable_async_checkpointing: True
+
+**Resuming Training with DistributedCheckpointer:**
+
+If your training was interrupted and you had async checkpointing enabled, you can resume from the latest
+distributed checkpoint by setting both ``resume_from_checkpoint`` and ``enable_async_checkpointing`` to ``True``:
+
+.. code-block:: yaml
+
+    # Set to True to resume from checkpoint
+    resume_from_checkpoint: True
+
+    # Set to True to enable asynchronous checkpointing
+    enable_async_checkpointing: True
+
+The DistributedCheckpointer will automatically locate and load the latest intermediate checkpoint from the
+output directory.
+
+.. note::
+
+    While asynchronous checkpointing improves training throughput, it's important to note that the final
+    checkpoint at the end of training is always saved synchronously to ensure all data is properly persisted
+    in a safetensors or torch.save format before the training job completes.
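+
+**Under the Hood (Illustrative Sketch):**
+
+The snippet below is a minimal sketch of the DCP ``async_save``/``load`` pattern that the DistributedCheckpointer
+builds on; it is not torchtune's actual implementation. The toy model, optimizer, and ``checkpoint_id`` path are
+placeholders, and it assumes a recent PyTorch where ``torch.distributed.checkpoint.async_save`` is available and
+the distributed process group has already been initialized by the training launcher.
+
+.. code-block:: python
+
+    import torch
+    import torch.distributed.checkpoint as dcp
+    from torch import nn
+
+    # Stand-ins for the model and optimizer built by the recipe
+    model = nn.Linear(16, 16)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+
+    state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
+
+    # async_save kicks off the write in the background and returns a future, so the
+    # training loop is not blocked ("output_dir/step_100" is a placeholder path)
+    save_future = dcp.async_save(state_dict, checkpoint_id="output_dir/step_100")
+
+    # ... training continues here ...
+
+    # Block on the pending save before issuing the next one (or at shutdown)
+    save_future.result()
+
+    # Resuming reads the per-rank files back into the state dict in place
+    dcp.load(state_dict, checkpoint_id="output_dir/step_100")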
+
+|
+
 Checkpoint Output
 ---------------------------------

From 9b4deb292659543c57c8d08ac7676b88287f4195 Mon Sep 17 00:00:00 2001
From: Ankita George
Date: Mon, 23 Jun 2025 16:19:09 -0700
Subject: [PATCH 2/2] use hf checkpointer

---
 docs/source/deep_dives/checkpointer.rst | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/docs/source/deep_dives/checkpointer.rst b/docs/source/deep_dives/checkpointer.rst
index 692c6962b2..1be305b80a 100644
--- a/docs/source/deep_dives/checkpointer.rst
+++ b/docs/source/deep_dives/checkpointer.rst
@@ -293,7 +293,7 @@ flag to ``True``. The DistributedCheckpointer will be automatically used when th
 
     checkpointer:
       # checkpointer to use for final checkpoints
-      _component_: torchtune.training.FullModelMetaCheckpointer
+      _component_: torchtune.training.FullModelHFCheckpointer
 
     # Set to True to enable asynchronous distributed checkpointing for intermediate checkpoints
     enable_async_checkpointing: True
@@ -316,9 +316,8 @@ output directory.
 
 .. note::
 
-    While asynchronous checkpointing improves training throughput, it's important to note that the final
-    checkpoint at the end of training is always saved synchronously to ensure all data is properly persisted
-    in a safetensors or torch.save format before the training job completes.
+    The final checkpoint at the end of training is always saved synchronously to ensure all
+    data is properly persisted in safetensors or torch.save format before the training job completes.
 
 |