
Add DCP async checkpointing for lora dpo recipe #2835

Merged · 11 commits · Jun 23, 2025
172 changes: 60 additions & 112 deletions recipes/lora_dpo_distributed.py
@@ -26,15 +26,17 @@
AdapterModule,
disable_adapter,
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.rlhf import ChosenRejectedOutputs
from torchtune.training import VALID_BACKENDS_FOR_MEMORY_STATS
from torchtune.training.checkpointing._checkpoint_client import (
CheckpointClient,
TrainingProgress,
)
from tqdm import tqdm


@@ -144,6 +146,8 @@ def __init__(self, cfg: DictConfig) -> None:

self._is_rank_zero = self.rank == 0

self._checkpoint_client = CheckpointClient(cfg)

# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -196,31 +200,6 @@ def __init__(self, cfg: DictConfig) -> None:
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> dict[str, Any]:
"""
Extract the checkpoint state from file and validate. This includes the
base model weights. If resume_from_checkpoint is True, this also includes
the adapter weights and recipe state
"""
self._checkpointer = config.instantiate(
cfg_checkpointer,
should_load_recipe_state=self._resume_from_checkpoint,
)
checkpoint_dict = self._checkpointer.load_checkpoint()

# When resuming from checkpoint for LoRA, the recipe expects the adapter weights
# and recipe state to be present. The keys should match up with what ``save_checkpoint``
# used to create these intermediate checkpoints
if self._resume_from_checkpoint:
if training.ADAPTER_KEY not in checkpoint_dict:
raise ValueError(
"Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
)
# _update_recipe_state will throw an exception if the recipe state is not correctly loaded
# no need to check here
self._update_recipe_state(checkpoint_dict)
return checkpoint_dict

def _update_recipe_state(self, ckpt_dict: dict[str, Any]) -> None:
"""
Updates the recipe state from checkpoint.
@@ -274,7 +253,7 @@ def setup(self, cfg: DictConfig) -> None:

utils.log_rank_zero(self._logger, "metric logger is initialized.")

checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

self._model = self._setup_model(
cfg_model=cfg.model,
@@ -286,7 +265,7 @@
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=(
checkpoint_dict[training.ADAPTER_KEY]
if self._resume_from_checkpoint
if training.ADAPTER_KEY in checkpoint_dict
else None
),
)
@@ -296,11 +275,38 @@
cfg_optimizer=cfg.optimizer,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if self._resume_from_checkpoint
if training.OPT_KEY in checkpoint_dict
else None
),
)

if self._resume_from_checkpoint:
# If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
# using the DistributedCheckpointer.
# Therefore the recipe needs to load the distributed checkpoint to restore the training
# progress.
if self._enable_async_checkpointing:
try:
checkpoint_dict = (
self._checkpoint_client.load_distributed_checkpoint(
self._model,
self._optimizer,
self._adapter_config,
)
)
except Exception as e:
self._logger.warning(
f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
)

if training.ADAPTER_KEY not in checkpoint_dict:
raise ValueError(
"Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
)

# Update the recipe state from the checkpoint state dict.
self._update_recipe_state(checkpoint_dict)

self._loss_fn = config.instantiate(cfg.loss)

utils.log_rank_zero(self._logger, "Loss is initialized.")
@@ -363,6 +369,17 @@ def _setup_model(
self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)

self._adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
"target_modules": get_lora_module_names(
self._lora_attn_modules,
self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
}

init_start = time.perf_counter()

utils.log_rank_zero(
@@ -541,89 +558,20 @@ def save_checkpoint(
self,
epoch: int,
) -> None:
"""
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
- Merged weights with key MODEL_KEY
- Adapter weights with key ADAPTER_KEY
- Relevant recipe state if training is not complete
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights

Checkpointer will save the merged weights, adapter weights and recipe state in
different checkpoint files. To correctly resume training, the adapter weights
and recipe state must be provided along with the base model weights."""
# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = training.gather_cpu_state_dict(
self._model,
self._is_rank_zero,
device=self._device,
adapter_weights_only=self._save_adapter_weights_only,
self._checkpoint_client.save_checkpoint(
model=self._model,
optimizer=self._optimizer,
training_progress=TrainingProgress(
seed=self.seed,
epochs_run=self.epochs_run,
total_epochs=self.total_epochs,
max_steps_per_epoch=self.max_steps_per_epoch,
dataloader_state_dict=self._dataloader.state_dict(),
),
epoch=epoch,
adapter_config=self._adapter_config.copy(),
adapter_only=self._save_adapter_weights_only,
)
if intermediate_checkpoint:
opt_state_dict = training.get_full_optimizer_state_dict(
self._model,
self._optimizer,
self._is_rank_zero,
device=self._device,
)
else:
opt_state_dict = None

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:
if self._save_adapter_weights_only:
adapter_state_dict = cpu_state_dict
else:
# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_state_dict = get_adapter_state_dict(cpu_state_dict)

# merge the adapter weights and base weights to create the model checkpoint
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
if intermediate_checkpoint:
checkpoint_dict.update(
{
training.OPT_KEY: opt_state_dict,
training.SEED_KEY: self.seed,
training.EPOCHS_KEY: self.epochs_run,
training.TOTAL_EPOCHS_KEY: self.total_epochs,
training.MAX_STEPS_KEY: self.max_steps_per_epoch,
training.DATALOADER_KEY: self._dataloader.state_dict(),
}
)

adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
"target_modules": get_lora_module_names(
self._lora_attn_modules,
self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
}
checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
self._checkpointer.save_checkpoint(
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
adapter_only=self._save_adapter_weights_only,
)

def concatenated_forward(
self, model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
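For context, the recipe-side change replaces the hand-rolled load_checkpoint/save_checkpoint logic with calls into CheckpointClient. The sketch below condenses the calls shown in the diff into one illustrative class; it simplifies the ordering and error handling, and the class wrapper itself is an assumption, not recipe code.

```python
# Condensed view of the recipe's new checkpointing path (sketch only; the full
# recipe interleaves these steps with model/optimizer setup as in the diff).
from omegaconf import DictConfig

from torchtune.training.checkpointing._checkpoint_client import (
    CheckpointClient,
    TrainingProgress,
)


class _CheckpointingSketch:
    def init_checkpointing(self, cfg: DictConfig) -> None:
        # Created once in __init__; wraps the configured checkpointer and the
        # DistributedCheckpointer used for async saves.
        self._checkpoint_client = CheckpointClient(cfg)

    def load_for_resume(self) -> dict:
        # Base model weights always come from the base checkpoint.
        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
        if self._resume_from_checkpoint and self._enable_async_checkpointing:
            # Async intermediate checkpoints are written by the DistributedCheckpointer,
            # so adapter weights and recipe state are restored from the DCP checkpoint.
            checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
                self._model, self._optimizer, self._adapter_config
            )
        return checkpoint_dict

    def save(self, epoch: int) -> None:
        # The client takes over what the recipe used to do by hand: gathering
        # state dicts, splitting out adapter weights, and writing the checkpoint
        # synchronously or asynchronously.
        self._checkpoint_client.save_checkpoint(
            model=self._model,
            optimizer=self._optimizer,
            training_progress=TrainingProgress(
                seed=self.seed,
                epochs_run=self.epochs_run,
                total_epochs=self.total_epochs,
                max_steps_per_epoch=self.max_steps_per_epoch,
                dataloader_state_dict=self._dataloader.state_dict(),
            ),
            epoch=epoch,
            adapter_config=self._adapter_config.copy(),
            adapter_only=self._save_adapter_weights_only,
        )
```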
87 changes: 87 additions & 0 deletions tests/recipes/test_lora_dpo_distributed.py
@@ -143,6 +143,93 @@ def test_training_state_on_resume(
resumed_loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
)

@pytest.mark.parametrize("save_adapter_weights_only", [False, True])
@gpu_test(gpu_count=4)
@pytest.mark.integration_test
def test_training_state_on_resume_with_async_checkpointing(
self, tmpdir, monkeypatch, save_adapter_weights_only
):
"""
Same as above with async checkpointing enabled.
"""

ckpt = "llama2_hf"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent
log_file = gen_log_file_name(tmpdir)

# Config file needed for model conversion.
# Create a second copy for training resume
write_hf_ckpt_config(ckpt_dir)
write_hf_ckpt_config(tmpdir)

# Train for two epochs
cmd_1 = f"""
tune run --nnodes 1 --nproc_per_node 4 lora_dpo_distributed \
--config llama2/7B_lora_dpo \
output_dir={tmpdir} \
model.lora_attn_modules=['q_proj','v_proj'] \
model.apply_lora_to_mlp=False \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
save_adapter_weights_only={save_adapter_weights_only} \
metric_logger.filename={log_file} \
enable_activation_checkpointing=True \
enable_activation_offloading=False \
enable_async_checkpointing=True \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2_lora"]

cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
monkeypatch.setattr(sys, "argv", cmd_1)
runpy.run_path(TUNE_PATH, run_name="__main__")

expected_loss_values = get_loss_values_from_metric_logger(log_file)

resumed_log_dir = (tmpdir / "resumed/").mkdir()
resumed_log_file = gen_log_file_name(resumed_log_dir)

# Resume training
epoch_folder = get_largest_iter_folder(tmpdir)
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 4 lora_dpo_distributed \
--config llama2/7B_lora_dpo \
output_dir={tmpdir} \
model.lora_attn_modules=['q_proj','v_proj'] \
model.apply_lora_to_mlp=False \
checkpointer=torchtune.training.FullModelHFCheckpointer \
checkpointer.checkpoint_dir={ckpt_dir} \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.adapter_checkpoint={os.path.join(tmpdir, epoch_folder_minus_one, f"{ADAPTER_MODEL_FNAME}.pt")} \
checkpointer.recipe_checkpoint={os.path.join(tmpdir, RECIPE_STATE_DIRNAME, "recipe_state.pt")} \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
resume_from_checkpoint=True \
metric_logger.filename={resumed_log_file} \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
enable_activation_checkpointing=True \
enable_activation_offloading=False \
enable_async_checkpointing=True \
""".split()
cmd_2 = cmd_2 + self._get_test_config_overrides(epochs=3) + model_config
monkeypatch.setattr(sys, "argv", cmd_2)
runpy.run_path(TUNE_PATH, run_name="__main__")

# Second epoch only
resumed_loss_values = get_loss_values_from_metric_logger(resumed_log_file)

torch.testing.assert_close(
resumed_loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
)

@pytest.mark.integration_test
@gpu_test(gpu_count=2)
def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
13 changes: 13 additions & 0 deletions torchtune/training/checkpointing/_checkpoint_client.py
@@ -249,6 +249,19 @@ def _save_checkpoint_sync(
optim_state_dict = {}

if is_not_distributed_checkpointer and not single_device:
# Staging an async checkpoint copies state to CPU, and saving a sync checkpoint
# also gathers state on CPU, which causes issues when the two happen concurrently.
# Wait for any in-flight async checkpoint to finish before saving a sync
# checkpoint that requires CPU gathering.
if self._get_dcp_checkpointer()._checkpoint_future is not None:
time_start_waiting = time.perf_counter()
self._get_dcp_checkpointer()._checkpoint_future.result()
if self._is_rank_zero:
log.info(
"Waiting for async checkpoint to finish, to save sync checkpoint ",
f"took {time.perf_counter() - time_start_waiting:.2f} secs",
)

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
model_state_dict = training.gather_cpu_state_dict(
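The guard added here waits on the DistributedCheckpointer's pending future so that CPU staging of an async save never overlaps with the CPU gathering a sync save needs. Below is a minimal standalone sketch of the same pattern using torch.distributed.checkpoint directly; the TinyCheckpointer class and paths are assumptions, it presumes an initialized process group, and async_save requires a recent PyTorch release.

```python
# Minimal sketch of "wait for the in-flight async save before a sync save".
# Assumes torch.distributed is already initialized; the class and paths are illustrative.
import time

import torch.distributed.checkpoint as dcp


class TinyCheckpointer:
    def __init__(self) -> None:
        self._checkpoint_future = None  # future returned by dcp.async_save

    def save_async(self, state_dict: dict, path: str) -> None:
        # Stages the state dict (CPU copy) and writes it in the background.
        self._checkpoint_future = dcp.async_save(state_dict, checkpoint_id=path)

    def save_sync(self, state_dict: dict, path: str) -> None:
        # Block on any pending async save first; async staging and sync saving
        # both compete for host memory if they run concurrently.
        if self._checkpoint_future is not None:
            start = time.perf_counter()
            self._checkpoint_future.result()
            print(f"waited {time.perf_counter() - start:.2f}s for async save to finish")
            self._checkpoint_future = None
        dcp.save(state_dict, checkpoint_id=path)
```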