Update get_merged_lora_ckpt for dist checkpoints #2834

Merged: 8 commits, Jun 23, 2025
Changes from 6 commits
1 change: 1 addition & 0 deletions recipes/full_finetune_single_device.py
@@ -289,6 +289,7 @@ def setup(self, cfg: DictConfig) -> None:
                ckpt_dict = self._checkpoint_client.load_distributed_checkpoint(
                    self._model,
                    self.optimizer,
                    single_device=True,
                )
            except Exception as e:
                self._logger.warning(
2 changes: 1 addition & 1 deletion recipes/knowledge_distillation_single_device.py
@@ -248,7 +248,7 @@ def setup(self, cfg: DictConfig) -> None:
                self._model,
                self._optimizer,
                self._adapter_config,
                self._save_adapter_weights_only,
                single_device=True,
            )

            if training.ADAPTER_KEY not in checkpoint_dict:
1 change: 1 addition & 0 deletions recipes/lora_dpo_single_device.py
@@ -233,6 +233,7 @@ def setup(self, cfg: DictConfig) -> None:
                self._model,
                self._optimizer,
                self._adapter_config,
                single_device=True,
            )

            if training.ADAPTER_KEY not in checkpoint_dict:
1 change: 1 addition & 0 deletions recipes/lora_finetune_single_device.py
@@ -285,6 +285,7 @@ def setup(self, cfg: DictConfig) -> None:
                self._model,
                self._optimizer,
                self._adapter_config,
                single_device=True,
            )

            if training.ADAPTER_KEY not in checkpoint_dict:
1 change: 1 addition & 0 deletions torchtune/modules/peft/__init__.py
@@ -11,6 +11,7 @@
    get_adapter_state_dict,
    get_lora_module_names,
    get_merged_lora_ckpt,
    get_merged_lora_dist_ckpt,
    LORA_ATTN_MODULES,
    set_trainable_params,
    validate_missing_and_unexpected_for_lora,
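With this re-export in place, the distributed merge helper is importable from the same namespace as the existing utility; a minimal usage reference (nothing beyond the import path is implied here):

from torchtune.modules.peft import get_merged_lora_ckpt, get_merged_lora_dist_ckpt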
118 changes: 118 additions & 0 deletions torchtune/modules/peft/_utils.py
@@ -8,6 +8,7 @@
from typing import Any, Generator, Literal, Optional, Protocol, runtime_checkable, Union

import torch
import torch.distributed as dist
from torch import nn
from torchtune.utils._logging import deprecate_parameter

@@ -256,6 +257,123 @@ def get_merged_lora_ckpt(
    return state_dict


@torch.no_grad
def get_merged_lora_dist_ckpt(
    state_dict: dict[str, Any],
    rank: int,
    alpha: float,
) -> dict[str, Any]:
    """
    Merge LoRA weights into the base model checkpoint format for efficient inference,
    using operations that are safe to run in a distributed job. This function is
    designed for distributed training scenarios: the merge is performed on every rank
    and synchronized with ``dist.barrier()`` so that ranks holding sharded (DTensor)
    weights stay in step.

    NOTE: This function modifies ``state_dict`` in place. If you do not want that,
    make a copy prior to calling this function.

    NOTE: This does not work for NF4Tensors, as they do not support the add and mul
    operations used here.

    For every LoRA module in the state dict, this function merges the LoRA
    decomposition into the base weight and then deletes the LoRA-specific parameters.

    Args:
        state_dict (dict[str, Any]): State dict from a model.
        rank (int): The rank of the LoRA matrices.
        alpha (float): The alpha value used for scaling the LoRA decompositions.

    Returns:
        dict[str, Any]: The merged state dict.
    """

    lora_modules = _get_lora_modules(state_dict)
    lora_moe_modules = _get_lora_moe_modules(state_dict)

    # Create a simple module for matrix multiplication
    class MatMulModule(torch.nn.Module):
Member: Why do we need this instead of just calling matmul directly?

Contributor Author: These operations don't work properly on DTensors; that's what was causing the hangs.

Member: "doesn't work properly" - can you expand on that?

Contributor Author: OK, actually I think it's not needed. Good catch. I thought it was causing problems, but I just re-tested without it and it still works. I'll get rid of them and just add barriers to the existing method.

        def forward(self, x, y):
            return (alpha / rank) * torch.matmul(x, y)

    for module in sorted(lora_modules.union(lora_moe_modules)):
        # TODO: we don't currently support DoRA for MoE layers
        if "experts" in module:
            for param in ["gate", "up", "down"]:
                lora_a_weight = state_dict[f"{module}.lora_{param}_a"]
                lora_b_weight = state_dict[f"{module}.lora_{param}_b"]

                # Create a simple module for transpose operation
                class TransposeModule(torch.nn.Module):
Member: Same here: why does this need to be a transpose module?

                    def __init__(self, dim0, dim1):
                        super().__init__()
                        self.dim0 = dim0
                        self.dim1 = dim1

                    def forward(self, x):
                        return torch.transpose(x, self.dim0, self.dim1)

                # Parallelize transpose operations
                transpose_module = TransposeModule(1, 2)
                dist.barrier()
                # Apply distributed transpose
                transposed_b = transpose_module(lora_b_weight)
                transposed_a = transpose_module(lora_a_weight)

                mm_module = MatMulModule()
                dist.barrier()
                result = mm_module(transposed_b, transposed_a)

                # Apply the result using out-of-place addition
                proj_weight = state_dict[f"{module}.{param}_proj"]

                dist.barrier()
                transposed_result = transpose_module(result)

                state_dict[f"{module}.{param}_proj"] = proj_weight + transposed_result

                del state_dict[f"{module}.lora_{param}_a"]
                del state_dict[f"{module}.lora_{param}_b"]
            continue

        lora_a_weight = state_dict[f"{module}.lora_a.weight"]
        lora_b_weight = state_dict[f"{module}.lora_b.weight"]
        lora_magnitude = state_dict.get(f"{module}.magnitude", None)

        # If magnitude is present, calculate merged DoRA weight
        if lora_magnitude is not None:
            base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)

            mm_module = MatMulModule()
            dist.barrier()
            lora_weight = mm_module(lora_b_weight, lora_a_weight)

            merged_weight = base_weight + lora_weight
            dist.barrier()

            # Create a simple module for norm calculation
            class NormModule(torch.nn.Module):
                def forward(self, x):
                    return torch.linalg.norm(x, dim=1)

            norm_module = NormModule()
            dist.barrier()
            weight_norm = norm_module(merged_weight)

            mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1)
            merged_weight *= mag_norm_scale
            state_dict[f"{module}.weight"] = merged_weight
            del state_dict[f"{module}.magnitude"]

        # Otherwise it is just vanilla LoRA
        else:
            mm_module = MatMulModule()
            dist.barrier()
            lora_weight = mm_module(
                lora_b_weight,
                lora_a_weight,
            )
            state_dict[f"{module}.weight"] += lora_weight

        del state_dict[f"{module}.lora_a.weight"]
        del state_dict[f"{module}.lora_b.weight"]

    dist.barrier()
    return state_dict


@contextlib.contextmanager
def disable_adapter(model: nn.Module) -> Generator[None, None, None]:
"""
Expand Down
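As a usage sketch of the new helper (illustrative only, not code from this PR): the function has to be called collectively, because it issues dist.barrier() internally, so every rank in the job must reach it together. The model, lora_rank, and lora_alpha names below are hypothetical stand-ins for a recipe's model and adapter config.

from torchtune.modules.peft import get_merged_lora_dist_ckpt


def merge_adapters_for_export(model, lora_rank: int, lora_alpha: float):
    # NOTE: the helper mutates the state dict in place and calls dist.barrier(),
    # so this function must be invoked on all ranks of the process group.
    state_dict = model.state_dict()
    return get_merged_lora_dist_ckpt(state_dict, rank=lora_rank, alpha=lora_alpha)

Within this PR itself, the checkpoint client applies the helper to ckpt_dict[training.MODEL_KEY] whenever single_device is False, as shown in the _checkpoint_client.py changes below.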
57 changes: 39 additions & 18 deletions torchtune/training/checkpointing/_checkpoint_client.py
@@ -23,6 +23,7 @@
from torchtune.modules.peft import (
    get_adapter_state_dict,
    get_merged_lora_ckpt,
    get_merged_lora_dist_ckpt,
    validate_missing_and_unexpected_for_lora,
)
from torchtune.training.checkpointing._checkpointer import DistributedCheckpointer
@@ -130,6 +131,7 @@ def _save_checkpoint_async(
        epoch: int,
        adapter_config: Optional[dict[str, Any]],
        adapter_only: bool,
        single_device: bool,
    ) -> None:
        """
        Checkpoint the training state asynchronously as a distributed checkpoint. Saving
@@ -166,22 +168,25 @@
                }
            )

            get_merged_lora_ckpt(
                ckpt_dict[training.MODEL_KEY],
                adapter_config["r"],
                adapter_config["lora_alpha"],
            )
            if single_device:
                get_merged_lora_ckpt(
                    ckpt_dict[training.MODEL_KEY],
                    adapter_config["r"],
                    adapter_config["lora_alpha"],
                )
            else:
                get_merged_lora_dist_ckpt(
                    ckpt_dict[training.MODEL_KEY],
                    adapter_config["r"],
                    adapter_config["lora_alpha"],
                )

        dcp_saver = self._get_dcp_checkpointer()
        if not adapter_only:
            dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True)

            if self._is_rank_zero:
                log.info(
                    f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
                )

        if adapter_config is not None:
            # save adapter weights first because it is faster
            # so will block training for less time
            # because you can only do async checkpointing one at a time
            adapter_start = time.perf_counter()

            save_path = dcp_saver.get_output_path(epoch=epoch)
@@ -205,6 +210,14 @@
f"Saving asynchronous checkpoint for adapter weights took {time.perf_counter() - adapter_start:.2f} secs"
)

if not adapter_only:
dcp_saver.save_checkpoint(ckpt_dict, epoch=epoch, save_async=True)

if self._is_rank_zero:
log.info(
f"Saving asynchronous checkpoint took {time.perf_counter() - cp_start:.2f} secs"
)

    def _save_checkpoint_sync(
        self,
        model: torch.nn.Module,
@@ -359,7 +372,6 @@ def save_checkpoint(
        checkpointer user has configured.
        """
        intermediate_checkpoint = epoch + 1 < training_progress.total_epochs

        if intermediate_checkpoint and self._enable_async_checkpointing:
            self._save_checkpoint_async(
                model,
@@ -368,6 +380,7 @@
                epoch,
                adapter_config,
                adapter_only,
                single_device,
            )
        else:
            self._save_checkpoint_sync(
@@ -392,6 +405,7 @@ def load_distributed_checkpoint(
        model: torch.nn.Module,
        optimizer: Union[torch.optim.Optimizer, OptimizerInBackwardWrapper],
        adapter_config: Optional[dict[str, Any]] = None,
        single_device: bool = False,
    ) -> dict[str, Any]:
        """
        This method is used to resume training from a distributed checkpoint state.
@@ -438,11 +452,18 @@
                }
            )

            get_merged_lora_ckpt(
                checkpoint_dict[training.MODEL_KEY],
                adapter_config["r"],
                adapter_config["lora_alpha"],
            )
            if single_device:
                get_merged_lora_ckpt(
                    checkpoint_dict[training.MODEL_KEY],
                    adapter_config["r"],
                    adapter_config["lora_alpha"],
                )
            else:
                get_merged_lora_dist_ckpt(
                    checkpoint_dict[training.MODEL_KEY],
                    adapter_config["r"],
                    adapter_config["lora_alpha"],
                )

        adapter_only = False
        dcp_checkpointer = self._get_dcp_checkpointer()
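Finally, a short sketch of what the resume path looks like from a recipe's point of view after this change; it simply mirrors the recipe diffs above, with the surrounding try/except and logging omitted:

# Single-device recipes pass single_device=True, so the client merges adapters with
# get_merged_lora_ckpt. Distributed recipes keep the default (single_device=False),
# and the client uses the new get_merged_lora_dist_ckpt, whose internal
# dist.barrier() calls require every rank to participate.
checkpoint_dict = self._checkpoint_client.load_distributed_checkpoint(
    self._model,
    self._optimizer,
    self._adapter_config,
    single_device=True,
)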