     AdapterModule,
     disable_adapter,
     get_adapter_params,
-    get_adapter_state_dict,
     get_lora_module_names,
-    get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.rlhf import ChosenRejectedOutputs
 from torchtune.training import VALID_BACKENDS_FOR_MEMORY_STATS
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from tqdm import tqdm
 
 
@@ -132,11 +134,14 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )
 
+        self._checkpoint_client = CheckpointClient(cfg)
         # Set up the backend for distributed training (NCCL, GLOO, etc.)
         self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
         self.distributed_backend = training.get_distributed_backend(
-            cfg.device, offload_ops_to_cpu=True
+            cfg.device,
+            offload_ops_to_cpu=self.fsdp_cpu_offload
+            or self._enable_async_checkpointing,
         )
         init_process_group(self.distributed_backend)
 
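The backend change above ties the process-group backend to whether collective ops may run on CPU tensors (FSDP CPU offload or async checkpointing), rather than unconditionally registering a CPU backend. The sketch below is a rough illustration of that idea, not torchtune's actual `get_distributed_backend` implementation; the helper name and the returned backend-string format are assumptions.

```python
# Hypothetical sketch: pick a process-group backend string based on the
# accelerator and whether collectives may run on CPU tensors.
def pick_distributed_backend(device_type: str, offload_ops_to_cpu: bool = False) -> str:
    accel_backends = {"cuda": "nccl", "cpu": "gloo"}
    backend = accel_backends.get(device_type, "gloo")
    if offload_ops_to_cpu and device_type != "cpu":
        # Register gloo for the CPU device alongside the accelerator backend so
        # collectives on offloaded (CPU) tensors still have a valid backend.
        backend = f"{device_type}:{backend},cpu:gloo"
    return backend


# Example: on CUDA with fsdp_cpu_offload or async checkpointing enabled this
# yields "cuda:nccl,cpu:gloo"; otherwise just "nccl".
```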
@@ -196,31 +201,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
 
-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. This includes the
-        base model weights. If resume_from_checkpoint is True, this also includes
-        the adapter weights and recipe state
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            should_load_recipe_state=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        # When resuming from checkpoint for LoRA, the recipe expects the adapter weights
-        # and recipe state to be present. The keys should match up with what ``save_checkpoint``
-        # used to create these intermediate checkpoints
-        if self._resume_from_checkpoint:
-            if training.ADAPTER_KEY not in checkpoint_dict:
-                raise ValueError(
-                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                )
-            # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
-            # no need to check here
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -274,7 +254,7 @@ def setup(self, cfg: DictConfig) -> None:
 
         utils.log_rank_zero(self._logger, "metric logger is initialized.")
 
-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
 
         self._model = self._setup_model(
             cfg_model=cfg.model,
@@ -286,7 +266,7 @@ def setup(self, cfg: DictConfig) -> None:
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
                 checkpoint_dict[training.ADAPTER_KEY]
-                if self._resume_from_checkpoint
+                if training.ADAPTER_KEY in checkpoint_dict
                 else None
             ),
         )
@@ -296,11 +276,38 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )
 
+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            self._optimizer,
+                            self._adapter_config,
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            if training.ADAPTER_KEY not in checkpoint_dict:
+                raise ValueError(
+                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         self._loss_fn = config.instantiate(cfg.loss)
 
         utils.log_rank_zero(self._logger, "Loss is initialized.")
@@ -363,6 +370,17 @@ def _setup_model(
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
 
+        self._adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
+
         init_start = time.perf_counter()
 
         utils.log_rank_zero(
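The `_adapter_config` built in this hunk uses the same field names a PEFT-style `adapter_config.json` carries (`r`, `lora_alpha`, `target_modules`, `peft_type`), which is what lets the saved adapter be described to tooling outside this recipe. A minimal sketch of what the serialized config might look like; the rank, alpha, and module names below are assumed example values, not taken from any particular config.

```python
import json

# Assumed example values; in the recipe these come from cfg_model and
# get_lora_module_names(...).
adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": ["q_proj", "k_proj", "v_proj", "output_proj"],
    "peft_type": "LORA",
}

# Serialized alongside the adapter weights so an external loader can
# reconstruct the LoRA layout without the training config.
print(json.dumps(adapter_config, indent=2))
```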
@@ -541,89 +559,20 @@ def save_checkpoint(
         self,
         epoch: int,
     ) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        Checkpointer will save the merged weights, adapter weights and recipe state in
-        different checkpoint files. To correctly resume from training, the adapter weights
-        and recipe state must be provided along with the base model weights."""
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model,
-            self._is_rank_zero,
-            device=self._device,
-            adapter_weights_only=self._save_adapter_weights_only,
+        self._checkpoint_client.save_checkpoint(
+            model=self._model,
+            optimizer=self._optimizer,
+            training_progress=TrainingProgress(
+                seed=self.seed,
+                epochs_run=self.epochs_run,
+                total_epochs=self.total_epochs,
+                max_steps_per_epoch=self.max_steps_per_epoch,
+                dataloader_state_dict=self._dataloader.state_dict(),
+            ),
+            epoch=epoch,
+            adapter_config=self._adapter_config.copy(),
+            adapter_only=self._save_adapter_weights_only,
         )
-        if intermediate_checkpoint:
-            opt_state_dict = training.get_full_optimizer_state_dict(
-                self._model,
-                self._optimizer,
-                self._is_rank_zero,
-                device=self._device,
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint dict
-        # to be sent to the checkpointer and ultimately written to file
-        if self._is_rank_zero:
-            if self._save_adapter_weights_only:
-                adapter_state_dict = cpu_state_dict
-            else:
-                # Filter out the adapter keys and weights from the model state dict. These will
-                # be saved separately
-                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
-
-                # merge the adapter weights and base weights to create the model checkpoint
-                merged_state_dict = get_merged_lora_ckpt(
-                    cpu_state_dict,
-                    rank=self._lora_rank,
-                    alpha=self._lora_alpha,
-                )
-                checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well.
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                        training.DATALOADER_KEY: self._dataloader.state_dict(),
-                    }
-                )
-
-            adapter_config = {
-                "r": self._lora_rank,
-                "lora_alpha": self._lora_alpha,
-                "target_modules": get_lora_module_names(
-                    self._lora_attn_modules,
-                    self._apply_lora_to_mlp,
-                    self._apply_lora_to_output,
-                ),
-                "peft_type": "LORA",
-            }
-            checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-                adapter_only=self._save_adapter_weights_only,
-            )
 
     def concatenated_forward(
         self, model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
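The `TrainingProgress` passed to the checkpoint client bundles the same recipe state the removed code stored under `training.SEED_KEY`, `training.EPOCHS_KEY`, and friends. The sketch below is a hypothetical stand-in that makes that correspondence explicit; the field names follow the call site above, while the `state_dict()` shape and the key mapping in the comments are assumptions for illustration, not torchtune's actual `TrainingProgress` implementation.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


# Hypothetical sketch of the recipe state now carried by TrainingProgress.
@dataclass
class TrainingProgressSketch:
    seed: int
    epochs_run: int
    total_epochs: int
    max_steps_per_epoch: Optional[int]
    dataloader_state_dict: dict[str, Any] = field(default_factory=dict)

    def state_dict(self) -> dict[str, Any]:
        # Assumed mapping to the recipe-state entries the old save path wrote.
        return {
            "seed": self.seed,  # training.SEED_KEY
            "epochs_run": self.epochs_run,  # training.EPOCHS_KEY
            "total_epochs": self.total_epochs,  # training.TOTAL_EPOCHS_KEY
            "max_steps_per_epoch": self.max_steps_per_epoch,  # training.MAX_STEPS_KEY
            "dataloader_state_dict": self.dataloader_state_dict,  # training.DATALOADER_KEY
        }
```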