
Commit 9e7fe63

lora dpo dcp
1 parent: b0f211e

File tree: 2 files changed, +149 −113 lines


recipes/lora_dpo_distributed.py

Lines changed: 62 additions & 113 deletions
@@ -26,15 +26,17 @@
     AdapterModule,
     disable_adapter,
     get_adapter_params,
-    get_adapter_state_dict,
     get_lora_module_names,
-    get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.rlhf import ChosenRejectedOutputs
 from torchtune.training import VALID_BACKENDS_FOR_MEMORY_STATS
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
 from tqdm import tqdm


@@ -132,11 +134,14 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )

+        self._checkpoint_client = CheckpointClient(cfg)
         # Set up the backend for distributed training (NCCL, GLOO, etc.)
         self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
         self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
         self.distributed_backend = training.get_distributed_backend(
-            cfg.device, offload_ops_to_cpu=True
+            cfg.device,
+            offload_ops_to_cpu=self.fsdp_cpu_offload
+            or self._enable_async_checkpointing,
         )
         init_process_group(self.distributed_backend)

@@ -196,31 +201,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

-    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> dict[str, Any]:
-        """
-        Extract the checkpoint state from file and validate. This includes the
-        base model weights. If resume_from_checkpoint is True, this also includes
-        the adapter weights and recipe state
-        """
-        self._checkpointer = config.instantiate(
-            cfg_checkpointer,
-            should_load_recipe_state=self._resume_from_checkpoint,
-        )
-        checkpoint_dict = self._checkpointer.load_checkpoint()
-
-        # When resuming from checkpoint for LoRA, the recipe expects the adapter weights
-        # and recipe state to be present. The keys should match up with what ``save_checkpoint``
-        # used to create these intermediate checkpoints
-        if self._resume_from_checkpoint:
-            if training.ADAPTER_KEY not in checkpoint_dict:
-                raise ValueError(
-                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
-                )
-            # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
-            # no need to check here
-            self._update_recipe_state(checkpoint_dict)
-        return checkpoint_dict
-
     def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None:
         """
         Updates the recipe state from checkpoint.
@@ -274,7 +254,7 @@ def setup(self, cfg: DictConfig) -> None:

         utils.log_rank_zero(self._logger, "metric logger is initialized.")

-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         self._model = self._setup_model(
             cfg_model=cfg.model,
@@ -286,7 +266,7 @@ def setup(self, cfg: DictConfig) -> None:
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
                 checkpoint_dict[training.ADAPTER_KEY]
-                if self._resume_from_checkpoint
+                if training.ADAPTER_KEY in checkpoint_dict
                 else None
             ),
         )
@@ -296,11 +276,38 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_optimizer=cfg.optimizer,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
-                if self._resume_from_checkpoint
+                if training.OPT_KEY in checkpoint_dict
                 else None
             ),
         )

+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            self._optimizer,
+                            self._adapter_config,
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            if training.ADAPTER_KEY not in checkpoint_dict:
+                raise ValueError(
+                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+                )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
         self._loss_fn = config.instantiate(cfg.loss)

         utils.log_rank_zero(self._logger, "Loss is initialized.")
@@ -363,6 +370,17 @@ def _setup_model(
         self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
         self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)

+        self._adapter_config = {
+            "r": self._lora_rank,
+            "lora_alpha": self._lora_alpha,
+            "target_modules": get_lora_module_names(
+                self._lora_attn_modules,
+                self._apply_lora_to_mlp,
+                self._apply_lora_to_output,
+            ),
+            "peft_type": "LORA",
+        }
+
         init_start = time.perf_counter()

         utils.log_rank_zero(
@@ -541,89 +559,20 @@ def save_checkpoint(
         self,
         epoch: int,
     ) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        Checkpointer will save the merged weights, adapter weights and recipe state in
-        different checkpoint files. To correctly resume from training, the adapter weights
-        and recipe state must be provided along with the base model weights."""
-        # final dict passed onto the checkpointer
-        checkpoint_dict = {}
-
-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-        # To prevent GPU memory from spiking during checkpoint save,
-        # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.gather_cpu_state_dict(
-            self._model,
-            self._is_rank_zero,
-            device=self._device,
-            adapter_weights_only=self._save_adapter_weights_only,
+        self._checkpoint_client.save_checkpoint(
+            model=self._model,
+            optimizer=self._optimizer,
+            training_progress=TrainingProgress(
+                seed=self.seed,
+                epochs_run=self.epochs_run,
+                total_epochs=self.total_epochs,
+                max_steps_per_epoch=self.max_steps_per_epoch,
+                dataloader_state_dict=self._dataloader.state_dict(),
+            ),
+            epoch=epoch,
+            adapter_config=self._adapter_config.copy(),
+            adapter_only=self._save_adapter_weights_only,
         )
-        if intermediate_checkpoint:
-            opt_state_dict = training.get_full_optimizer_state_dict(
-                self._model,
-                self._optimizer,
-                self._is_rank_zero,
-                device=self._device,
-            )
-        else:
-            opt_state_dict = None
-
-        # Now that we have the model and opt state dict, create the actual checkpoint dict
-        # to be sent to the checkpointer and ultimately written to file
-        if self._is_rank_zero:
-            if self._save_adapter_weights_only:
-                adapter_state_dict = cpu_state_dict
-            else:
-                # Filter out the adapter keys and weights from the model state dict. These will
-                # be saved separately
-                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
-
-            # merge the adapter weights and base weights to create the model checkpoint
-            merged_state_dict = get_merged_lora_ckpt(
-                cpu_state_dict,
-                rank=self._lora_rank,
-                alpha=self._lora_alpha,
-            )
-            checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # if training is in-progress, checkpoint the optimizer state and recipe state
-            # as well.
-            if intermediate_checkpoint:
-                checkpoint_dict.update(
-                    {
-                        training.OPT_KEY: opt_state_dict,
-                        training.SEED_KEY: self.seed,
-                        training.EPOCHS_KEY: self.epochs_run,
-                        training.TOTAL_EPOCHS_KEY: self.total_epochs,
-                        training.MAX_STEPS_KEY: self.max_steps_per_epoch,
-                        training.DATALOADER_KEY: self._dataloader.state_dict(),
-                    }
-                )
-
-            adapter_config = {
-                "r": self._lora_rank,
-                "lora_alpha": self._lora_alpha,
-                "target_modules": get_lora_module_names(
-                    self._lora_attn_modules,
-                    self._apply_lora_to_mlp,
-                    self._apply_lora_to_output,
-                ),
-                "peft_type": "LORA",
-            }
-            checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
-            self._checkpointer.save_checkpoint(
-                checkpoint_dict,
-                epoch=epoch,
-                intermediate_checkpoint=intermediate_checkpoint,
-                adapter_only=self._save_adapter_weights_only,
-            )

     def concatenated_forward(
         self, model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
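Net effect of the recipe changes: the hand-rolled load_checkpoint/save_checkpoint logic is replaced by the CheckpointClient API, and the adapter config is built once in _setup_model so it can be passed to both distributed loading and saving. A minimal sketch of the resulting flow is shown below; the CheckpointClient/TrainingProgress names, methods, and keyword arguments are taken from the diff above, while the standalone function and its parameters are illustrative stand-ins for the recipe's self._* attributes and are not part of the commit.

# Illustrative sketch of the new checkpointing flow; parameters stand in for the
# recipe's attributes (self._model, self._optimizer, self._dataloader, ...).
from torchtune import training
from torchtune.training.checkpointing._checkpoint_client import (
    CheckpointClient,
    TrainingProgress,
)


def checkpoint_flow_sketch(
    cfg, model, optimizer, dataloader, adapter_config,
    *, resume_from_checkpoint, enable_async_checkpointing,
    seed, epochs_run, total_epochs, max_steps_per_epoch, epoch,
    save_adapter_weights_only,
):
    client = CheckpointClient(cfg)

    # Base model weights (and, when present, adapter/optimizer/recipe state).
    checkpoint_dict = client.load_base_checkpoint()

    # With async checkpointing, intermediate state lives in a distributed (DCP)
    # checkpoint, so resuming loads it directly into the wrapped model/optimizer.
    if resume_from_checkpoint and enable_async_checkpointing:
        checkpoint_dict = client.load_distributed_checkpoint(
            model, optimizer, adapter_config
        )
    if resume_from_checkpoint and training.ADAPTER_KEY not in checkpoint_dict:
        raise ValueError("Adapter weights not found.")

    # Saving: one call covers merged weights, adapter weights, and recipe state.
    client.save_checkpoint(
        model=model,
        optimizer=optimizer,
        training_progress=TrainingProgress(
            seed=seed,
            epochs_run=epochs_run,
            total_epochs=total_epochs,
            max_steps_per_epoch=max_steps_per_epoch,
            dataloader_state_dict=dataloader.state_dict(),
        ),
        epoch=epoch,
        adapter_config=adapter_config.copy(),
        adapter_only=save_adapter_weights_only,
    )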

tests/recipes/test_lora_dpo_distributed.py

Lines changed: 87 additions & 0 deletions
@@ -143,6 +143,93 @@ def test_training_state_on_resume(
             resumed_loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
         )

+    @pytest.mark.parametrize("save_adapter_weights_only", [False, True])
+    @gpu_test(gpu_count=4)
+    @pytest.mark.integration_test
+    def test_training_state_on_resume(
+        self, tmpdir, monkeypatch, save_adapter_weights_only
+    ):
+        """
+        Same as above with async checkpointing enabled.
+        """
+
+        ckpt = "llama2_hf"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        # Create a second copy for training resume
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for two epochs
+        cmd_1 = f"""
+        tune run --nnodes 1 --nproc_per_node 4 lora_dpo_distributed \
+            --config llama2/7B_lora_dpo \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+            save_adapter_weights_only={save_adapter_weights_only} \
+            metric_logger.filename={log_file} \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=False \
+            enable_async_checkpointing=True \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS["llama2_lora"]
+
+        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+        monkeypatch.setattr(sys, "argv", cmd_1)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        expected_loss_values = get_loss_values_from_metric_logger(log_file)
+
+        resumed_log_dir = (tmpdir / "resumed/").mkdir()
+        resumed_log_file = gen_log_file_name(resumed_log_dir)
+
+        # Resume training
+        epoch_folder = get_largest_iter_folder(tmpdir)
+        epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
+        cmd_2 = f"""
+        tune run --nnodes 1 --nproc_per_node 4 lora_dpo_distributed \
+            --config llama2/7B_lora_dpo \
+            output_dir={tmpdir} \
+            model.lora_attn_modules=['q_proj','v_proj'] \
+            model.apply_lora_to_mlp=False \
+            checkpointer=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir={ckpt_dir} \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")}
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            resume_from_checkpoint=True \
+            metric_logger.filename={resumed_log_file} \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+            enable_activation_checkpointing=True \
+            enable_activation_offloading=False \
+            enable_async_checkpointing=True \
+        """.split()
+        cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
+        monkeypatch.setattr(sys, "argv", cmd_2)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Second epoch only
+        resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)
+
+        torch.testing.assert_close(
+            resumed_loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
+        )
+
     @pytest.mark.integration_test
     @gpu_test(gpu_count=2)
     def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):