Update get_merged_lora_ckpt for dist checkpoints #2834
Merged

ankitageorge merged 8 commits into pytorch:main from ankitageorge:fix-dist-merged-weights on Jun 23, 2025.
Changes shown are from 6 of 8 commits.

Commits:
bca098d  dist merge (ankitageorge)
172174c  add single device as arg (ankitageorge)
72d5636  fix vanilla lora (ankitageorge)
b0f211e  change order or checkpoint save (ankitageorge)
00be546  fix load too (ankitageorge)
2d2e06a  Merge branch 'main' into fix-dist-merged-weights (ankitageorge)
6e29202  separate dist method not needed (ankitageorge)
9a0a6eb  fix config (ankitageorge)
@@ -8,6 +8,7 @@
from typing import Any, Generator, Literal, Optional, Protocol, runtime_checkable, Union

import torch
import torch.distributed as dist
from torch import nn
from torchtune.utils._logging import deprecate_parameter
@@ -256,6 +257,123 @@ def get_merged_lora_ckpt(
    return state_dict


@torch.no_grad
def get_merged_lora_dist_ckpt(
    state_dict: dict[str, Any],
    rank: int,
    alpha: float,
) -> dict[str, Any]:
    """
    Merge LoRA weights into the base model format for efficient inference. This
    function is designed for distributed training scenarios and uses distributed
    operations such as distributed matrix multiplication.

    NOTE: This function modifies state_dict in place. If you do not want that,
    make a copy prior to calling this function.

    NOTE: This does not work for NF4Tensors as they don't support the add and mul
    operations used here.

    For every LoRA module in the state dict, this function will convert its base
    weight and then delete the LoRA-specific parameters.

    Args:
        state_dict (dict[str, Any]): State dict from a model.
        rank (int): The rank of the LoRA matrices.
        alpha (float): The alpha value used for scaling LoRA decompositions.

    Returns:
        dict[str, Any]: The merged state dict.
    """
    lora_modules = _get_lora_modules(state_dict)
    lora_moe_modules = _get_lora_moe_modules(state_dict)

    # Create a simple module for matrix multiplication
    class MatMulModule(torch.nn.Module):
        def forward(self, x, y):
            return (alpha / rank) * torch.matmul(x, y)

    for module in sorted(lora_modules.union(lora_moe_modules)):
        # TODO: we don't currently support DoRA for MoE layers
        if "experts" in module:
            for param in ["gate", "up", "down"]:
                lora_a_weight = state_dict[f"{module}.lora_{param}_a"]
                lora_b_weight = state_dict[f"{module}.lora_{param}_b"]

                # Create a simple module for transpose operation
                class TransposeModule(torch.nn.Module):
                    # [inline review comment] Same here: why does this need to be a transpose module?
                    def __init__(self, dim0, dim1):
                        super().__init__()
                        self.dim0 = dim0
                        self.dim1 = dim1

                    def forward(self, x):
                        return torch.transpose(x, self.dim0, self.dim1)

                # Parallelize transpose operations
                transpose_module = TransposeModule(1, 2)
                dist.barrier()
                # Apply distributed transpose
                transposed_b = transpose_module(lora_b_weight)
                transposed_a = transpose_module(lora_a_weight)

                mm_module = MatMulModule()
                dist.barrier()
                result = mm_module(transposed_b, transposed_a)

                # Apply the result using out-of-place addition
                proj_weight = state_dict[f"{module}.{param}_proj"]

                dist.barrier()
                transposed_result = transpose_module(result)

                state_dict[f"{module}.{param}_proj"] = proj_weight + transposed_result

                del state_dict[f"{module}.lora_{param}_a"]
                del state_dict[f"{module}.lora_{param}_b"]
            continue

        lora_a_weight = state_dict[f"{module}.lora_a.weight"]
        lora_b_weight = state_dict[f"{module}.lora_b.weight"]
        lora_magnitude = state_dict.get(f"{module}.magnitude", None)

        # If magnitude is present, calculate merged DoRA weight
        if lora_magnitude is not None:
            base_weight = state_dict[f"{module}.weight"].to(lora_a_weight.dtype)

            mm_module = MatMulModule()
            dist.barrier()
            lora_weight = mm_module(lora_b_weight, lora_a_weight)

            merged_weight = base_weight + lora_weight
            dist.barrier()

            # Create a simple module for norm calculation
            class NormModule(torch.nn.Module):
                def forward(self, x):
                    return torch.linalg.norm(x, dim=1)

            norm_module = NormModule()
            dist.barrier()
            weight_norm = norm_module(merged_weight)

            mag_norm_scale = (lora_magnitude / weight_norm).view(-1, 1)
            merged_weight *= mag_norm_scale
            state_dict[f"{module}.weight"] = merged_weight
            del state_dict[f"{module}.magnitude"]

        # Otherwise it is just vanilla LoRA
        else:
            mm_module = MatMulModule()
            dist.barrier()
            lora_weight = mm_module(
                lora_b_weight,
                lora_a_weight,
            )
            state_dict[f"{module}.weight"] += lora_weight

        del state_dict[f"{module}.lora_a.weight"]
        del state_dict[f"{module}.lora_b.weight"]

    dist.barrier()
    return state_dict


@contextlib.contextmanager
def disable_adapter(model: nn.Module) -> Generator[None, None, None]:
    """
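For reference, the merge performed by the new function reduces to folding the scaled low-rank product into the base weight, plus a per-row rescaling when a DoRA magnitude is present. The following is a minimal sketch of that math on plain local tensors; the shapes and variable names are illustrative assumptions, not part of the PR.

import torch

rank, alpha = 8, 16.0
out_dim, in_dim = 64, 32

base_weight = torch.randn(out_dim, in_dim)   # {module}.weight
lora_a = torch.randn(rank, in_dim)           # {module}.lora_a.weight
lora_b = torch.randn(out_dim, rank)          # {module}.lora_b.weight

# Vanilla LoRA: same math as MatMulModule, i.e. (alpha / rank) * matmul(lora_b, lora_a)
# folded into the base weight.
merged = base_weight + (alpha / rank) * torch.matmul(lora_b, lora_a)

# DoRA: additionally rescale each output row by magnitude / row norm,
# mirroring the NormModule and mag_norm_scale steps above.
magnitude = torch.ones(out_dim)              # {module}.magnitude
weight_norm = torch.linalg.norm(merged, dim=1)
merged = merged * (magnitude / weight_norm).view(-1, 1)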
Review discussion:

Reviewer: Why do we need this instead of just calling matmul directly?

Author: These operations don't work properly on DTensors; it's what was causing the hangs.

Reviewer: "Doesn't work properly" - can you expand on that?

Author: Ok, actually, I think it's not needed. Good catch. I thought it was causing problems, but I just re-tested without it and it still works. I'll get rid of them and just add barriers to the existing method.
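Based on that resolution, the wrapper modules would be dropped in favor of calling the tensor ops directly while keeping the barriers. Below is a rough sketch of what the vanilla-LoRA branch might look like after that simplification; the helper is hypothetical, not the merged PR code, and it assumes the weights support matmul directly (e.g. regular tensors or DTensors with compatible placements).

import torch
import torch.distributed as dist

def merge_vanilla_lora_inplace(state_dict, module, rank, alpha):
    # Hypothetical helper: fold one LoRA module's low-rank update into its
    # base weight with direct tensor ops, keeping a barrier so all ranks
    # stay in sync before the update.
    lora_a = state_dict[f"{module}.lora_a.weight"]
    lora_b = state_dict[f"{module}.lora_b.weight"]
    if dist.is_initialized():
        dist.barrier()
    state_dict[f"{module}.weight"] += (alpha / rank) * torch.matmul(lora_b, lora_a)
    del state_dict[f"{module}.lora_a.weight"]
    del state_dict[f"{module}.lora_b.weight"]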