@@ -266,6 +266,61 @@ for testing or for loading quantized models for generation.
|
+ :class:`DistributedCheckpointer <torchtune.training.DistributedCheckpointer>`
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ This checkpointer reads and writes checkpoints in a distributed format using PyTorch Distributed Checkpointing (DCP).
+ The output format is DCP's default format, which saves the state dict across all ranks as separate files, one per rank. This differs
+ from the other checkpointer implementations, which consolidate the state dict on rank-0 and then save it as full tensors.
+ The distributed checkpointer is used when asynchronous checkpointing is enabled during training and relies on DCP's ``async_save`` API.
+ Async distributed checkpointing is only used for intermediate checkpoints.
+
+ When asynchronous checkpointing is enabled, intermediate checkpoints are saved using the DistributedCheckpointer
+ without blocking the training process. This is particularly useful for large models where saving checkpoints
+ can take significant time.
+
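+ To make this concrete, the snippet below is a minimal sketch of the kind of call the
+ DistributedCheckpointer issues under the hood, using DCP's ``async_save`` directly. It is
+ illustrative only: the ``app_state`` contents, the ``model``/``optimizer`` objects, and the
+ ``checkpoint_id`` path are assumptions, and torchtune performs these steps for you when
+ async checkpointing is enabled.
+
+ .. code-block:: python
+
+     import torch.distributed.checkpoint as dcp
+
+     # Each rank contributes its (possibly sharded) state dict; DCP writes
+     # separate files per rank plus a small metadata file.
+     app_state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
+
+     # async_save returns a future immediately, so the training loop keeps
+     # running while the checkpoint is written in the background.
+     checkpoint_future = dcp.async_save(app_state, checkpoint_id="output_dir/epoch_1")
+
+     # ... continue training ...
+
+     # Block only when the checkpoint must be fully persisted, e.g. before
+     # starting the next save.
+     checkpoint_future.result()
+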
+ **Key Features:**
+
+ - **Asynchronous Saving**: Allows training to continue while checkpoints are being saved in the background
+ - **Distributed-Aware**: Designed to work seamlessly in multi-GPU and multi-node training setups
+
+ **Configuration Example:**
+
+ To enable asynchronous checkpointing in your training config, you need to set the ``enable_async_checkpointing``
+ flag to ``True``. The DistributedCheckpointer will be automatically used when this flag is enabled.
+
+ .. code-block:: yaml
+
+     checkpointer:
+       # checkpointer to use for final checkpoints
+       _component_: torchtune.training.FullModelHFCheckpointer
+
+     # Set to True to enable asynchronous distributed checkpointing for intermediate checkpoints
+     enable_async_checkpointing: True
+
+ **Resuming Training with DistributedCheckpointer:**
+
+ If your training was interrupted and you had async checkpointing enabled, you can resume from the latest
+ distributed checkpoint by setting both ``resume_from_checkpoint`` and ``enable_async_checkpointing`` to ``True``:
+
+ .. code-block:: yaml
+
+     # Set to True to resume from checkpoint
+     resume_from_checkpoint: True
+
+     # Set to True to enable asynchronous checkpointing
+     enable_async_checkpointing: True
+
+ The DistributedCheckpointer will automatically locate and load the latest intermediate checkpoint from the
315
+ output directory.
+
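+ For reference, loading a DCP-format checkpoint outside of torchtune looks roughly like the
+ sketch below. The ``app_state`` layout and the checkpoint path are assumptions; when resuming
+ through the config flags above, the DistributedCheckpointer performs the equivalent load for you.
+
+ .. code-block:: python
+
+     import torch.distributed.checkpoint as dcp
+
+     # dcp.load restores tensors in place into the provided state dict, so the
+     # model and optimizer must already be constructed (and sharded) as they
+     # were during training.
+     app_state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
+     dcp.load(app_state, checkpoint_id="output_dir/epoch_1")
+
+     # Push the restored values back into the live objects.
+     model.load_state_dict(app_state["model"])
+     optimizer.load_state_dict(app_state["optimizer"])
+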
+ .. note::
+
+     The final checkpoint at the end of training is always saved synchronously to ensure all
+     data is properly persisted in safetensors or torch.save format before the training job completes.
+
+ |
+
Checkpoint Output
---------------------------------