Commit 4346513
ENH: Targeting multiple parameters on the same module (#2665)
When the target_parameters feature for LoRA was introduced in #2638, there was one gap: the possibility to target multiple nn.Parameters on the same module (there was only a workaround involving multiple adapters, which is not user friendly). With this PR, it is now possible to achieve this.

The mechanism to enable this is a bit crude, namely allowing multiple ParamWrappers to be nested. This should generally be fine as long as only a couple of nn.Parameters are targeted on the same module. When there are dozens or hundreds, this approach could lead to slowdowns or other issues.

A side effect of this implementation is that the ParamWrapper, when it removes the parametrization, now only removes its own parametrization. Calling nn.utils.parametrize.remove_parametrizations removes all parametrizations, which is bad when we have nested parametrizations.

Alternative approaches

Some alternative approaches were discussed internally, but the chosen one was considered the most practical:

- Allow more than one adapted parameter per LoRA layer. This would require nested dicts for the LoRA parameters, something like self.lora_A[adapter_name][parameter_name]. We don't have this anywhere so far and it would probably break implicit assumptions about PEFT layers in many places (like parsing of state_dict keys), requiring many adjustments.
- Have an auxiliary module that contains the individual LoRA layers targeting the individual parameters. This could be the cleanest solution and would probably be more efficient if there is a huge number of targeted parameters per module. However, it also brings extra complexity, as it requires implementing the logic to route the information to the right parameter, and it may be a solution to a problem that is irrelevant in practice (a large number of targets per module).
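For reference, a minimal usage sketch of the new capability (not taken verbatim from the PR; the parameter names and the tiny model id mirror the MoE experts case exercised in the test suite): a single adapter now targets two nn.Parameters living on the same module.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# assumption: a Llama4-style MoE model whose experts module exposes both
# `gate_up_proj` and `down_proj` as plain nn.Parameters
base_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Llama4ForCausalLM")

config = LoraConfig(
    target_parameters=[
        "feed_forward.experts.gate_up_proj",
        "feed_forward.experts.down_proj",
    ],
)
peft_model = get_peft_model(base_model, config)  # both parameters receive LoRA updates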
1 parent 43845f9 commit 4346513

File tree

4 files changed: +132 -67 lines changed


src/peft/tuners/lora/layer.py

Lines changed: 42 additions & 13 deletions
@@ -1777,6 +1777,15 @@ def forward(self, W):
         return W + self.delta_weight


+# copied from:
+# https://github.com/pytorch/pytorch/blob/5e386eec9426f174eea130c0c012d9f65ebe65fb/torch/nn/utils/parametrize.py#L75-L79
+def _register_parameter_or_buffer(module, name, X):
+    if isinstance(X, nn.Parameter):
+        module.register_parameter(name, X)
+    else:
+        module.register_buffer(name, X)
+
+
 class ParamWrapper(nn.Module, LoraLayer):
     """A LoRA wrapper for `nn.Parameter`. This layer is dispatched if users target a parameter directly with
     `lora_config.target_parameters`
@@ -1807,8 +1816,8 @@ def __init__(
     ) -> None:
         super().__init__()
         LoraLayer.__init__(self, base_layer, **kwargs)
-        param = getattr(base_layer, parameter_name)
         self.parameter_name = parameter_name
+        param = self.get_param()
         if param.ndim == 3:
             self.num_experts, self.in_features, self.out_features = param.shape
         else:
@@ -1867,15 +1876,6 @@ def update_layer(
         # This code works for linear layers, override for other layer types
         if r <= 0:
             raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
-        if adapter_name in self.lora_A:
-            # It is not allowed to target multiple parameters on the same module. Supporting this would complicate
-            # things quite a lot, since we would require multiple self.lora_A, self.lora_B, etc., one for each targeted
-            # parameter.
-            raise ValueError(
-                f"lora.{self.__class__.__name__} already has an adapter for parameter '{self.parameter_name}'. "
-                "It is currently not possible to apply the same adapter to multiple parameters, please add a "
-                "different adapter to target another parameter of the same module."
-            )

         lora_variant = self.resolve_lora_variant(
             use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
@@ -1958,7 +1958,8 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio
             adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)

     def get_param(self):
-        return getattr(self.base_layer, self.parameter_name)
+        param = getattr(self.get_base_layer(), self.parameter_name)
+        return param

     def get_delta_weight(self, adapter_name, *args, **kwargs):
         if self.num_experts == 1:
@@ -2004,10 +2005,26 @@ def _activate_lora(self, active_adapters: list[str]):
         try:
             yield
         finally:
-            nn.utils.parametrize.remove_parametrizations(
-                self.base_layer, self.parameter_name, leave_parametrized=False
+            self._remove_parametrizations()
+
+    def _remove_parametrizations(self):
+        # Remove the parametrization of this specific parameter
+        base_layer = self.get_base_layer()
+        parameter_name = self.parameter_name
+        if parameter_name not in base_layer.parametrizations:
+            raise ValueError(
+                "Something went wrong, please report this issue on PEFT: https://github.com/huggingface/peft/issues"
             )

+        if len(base_layer.parametrizations[parameter_name]) == 1:
+            # last parametrization, we can safely remove it completely
+            nn.utils.parametrize.remove_parametrizations(base_layer, parameter_name, leave_parametrized=False)
+        else:
+            # TODO: If there are multiple parametrizations for the same parameter_name, we currently remove all of them,
+            # which is not desired. Unfortunately, PyTorch does not support this directly, so we need to take care.
+            # For now, remove all parametrizations.
+            nn.utils.parametrize.remove_parametrizations(base_layer, parameter_name, leave_parametrized=False)
+
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         # same as lora.Linear.merge but not hard-coding base_layer.weight and without special cases like variants removed
         adapter_names = check_adapters_to_merge(self, adapter_names)
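For context on why _remove_parametrizations distinguishes the single- and multi-parametrization cases: in plain PyTorch, several parametrizations can be stacked on the same tensor, but nn.utils.parametrize.remove_parametrizations always drops the entire stack, and there is no public helper to pop a single entry. A minimal standalone sketch (plain PyTorch, not PEFT code):

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class AddOne(nn.Module):
    def forward(self, W):
        return W + 1.0

lin = nn.Linear(2, 2)
parametrize.register_parametrization(lin, "weight", AddOne())
parametrize.register_parametrization(lin, "weight", AddOne())  # stacking a second parametrization is allowed

# both entries live in the ParametrizationList registered for "weight"
assert len(lin.parametrizations["weight"]) == 2

# this removes *all* parametrizations registered on "weight" at once
parametrize.remove_parametrizations(lin, "weight", leave_parametrized=False)
assert not parametrize.is_parametrized(lin, "weight")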
@@ -2059,6 +2076,18 @@ def _check_forward_args(self, x, *args, **kwargs):
             raise ValueError(f"lora.{self.__class__.__name__} does not support mixed adapter batches yet.")
         super()._check_forward_args(x, *args, **kwargs)

+    def unload_and_optionally_merge_module(self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]):
+        base_layer = self.base_layer
+        # ParamWrappers can be nested, so merge and retrieve base layer recursively
+        if merge:
+            self.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+            while isinstance(base_layer, ParamWrapper):
+                base_layer.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+                base_layer = base_layer.base_layer
+        else:
+            base_layer = self.get_base_layer()
+        return base_layer
+
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._check_forward_args(x, *args, **kwargs)
         adapter_names = kwargs.pop("adapter_names", None)

src/peft/tuners/lora/model.py

Lines changed: 9 additions & 2 deletions
@@ -53,7 +53,7 @@
 from .gptq import dispatch_gptq
 from .hqq import dispatch_hqq
 from .inc import dispatch_inc
-from .layer import Conv2d, LoraLayer, dispatch_default
+from .layer import Conv2d, LoraLayer, ParamWrapper, dispatch_default
 from .torchao import dispatch_torchao
 from .tp_layer import dispatch_megatron

@@ -227,7 +227,9 @@ def _create_and_replace(
         # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
         from peft.tuners.adalora import AdaLoraLayer

-        if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
+        # if the target is a ParamWrapper, we nest it to allow targeting multiple nn.Parameter on the same module
+        wrap_target_param = isinstance(target, ParamWrapper) and (adapter_name in target.lora_A)
+        if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer) and not wrap_target_param:
             target.update_layer(
                 adapter_name,
                 r,
@@ -239,6 +241,11 @@
                 lora_bias=lora_config.lora_bias,
             )
         else:
+            if isinstance(target, ParamWrapper) and (parameter_name == target.parameter_name):
+                raise ValueError(
+                    "Trying to target the same nn.Parameter twice, this should not happen. Please open an issue on the "
+                    "PEFT repo: https://github.com/huggingface/peft/issues"
+                )
             device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
             new_module = self._create_new_module(lora_config, adapter_name, target, device_map=device_map, **kwargs)
             if adapter_name not in self.active_adapters:
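The nesting produced here can be inspected by walking base_layer attributes. The following is an illustrative sketch only (the module path is hypothetical and assumes the Llama4-style MoE model used in the tests); it relies on the ParamWrapper attributes shown in the layer.py diff above (parameter_name, base_layer, get_base_layer):

from peft.tuners.lora.layer import ParamWrapper

# after get_peft_model(...) with two target_parameters entries pointing at the same module:
module = peft_model.base_model.model.model.layers[0].feed_forward.experts  # hypothetical path
while isinstance(module, ParamWrapper):
    print(type(module).__name__, "wraps parameter:", module.parameter_name)
    module = module.base_layer
# `module` is now the original experts module; calling get_base_layer() on the
# outermost wrapper unwraps to the same object in one call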

tests/test_initialization.py

Lines changed: 0 additions & 24 deletions
@@ -1411,30 +1411,6 @@ def __init__(self):
         with pytest.raises(ValueError, match=msg):
             get_peft_model(base_model, config)

-    def test_targeting_2_params_on_1_module_raises(self):
-        # It is currently not supported to target multiple parameters on the same module.
-        class ModuleWith2Params(nn.Module):
-            def __init__(self, in_features, out_features):
-                super().__init__()
-                self.weight0 = nn.Parameter(torch.zeros(in_features, out_features))
-                self.weight1 = nn.Parameter(torch.ones(3, out_features, out_features))
-
-        class Outer(nn.Module):
-            def __init__(self, in_features, out_features):
-                super().__init__()
-                self.lin = nn.Linear(in_features, in_features)
-                self.submodule = ModuleWith2Params(in_features, out_features)
-
-        model = Outer(3, 4)
-        config = LoraConfig(target_parameters=["submodule.weight0", "submodule.weight1"], init_lora_weights=False)
-        msg = (
-            "lora.ParamWrapper already has an adapter for parameter 'weight0'. It is currently not possible to apply "
-            "the same adapter to multiple parameters, please add a different adapter to target another parameter of "
-            "the same module."
-        )
-        with pytest.raises(ValueError, match=msg):
-            get_peft_model(model, config)
-
     @pytest.mark.parametrize("target_parameters", [["linear"], ["foobar"], ["foobar.weight"], ["foo", "bar"]])
     @pytest.mark.parametrize("target_modules", [None, [], ""])
     def test_valid_no_target_module_nor_target_parameter_match_raises(self, target_parameters, target_modules):

tests/test_target_parameters.py

Lines changed: 81 additions & 28 deletions
@@ -59,6 +59,22 @@
             ],
         },
     ),
+    # target down_proj and gate_up_proj on the same module
+    (
+        LoraConfig,
+        {
+            "task_type": "CAUSAL_LM",
+            "r": 8,
+            "lora_alpha": 32,
+            "target_modules": None,
+            "lora_dropout": 0.0,
+            "bias": "none",
+            "target_parameters": [
+                "feed_forward.experts.down_proj",
+                "feed_forward.experts.gate_up_proj",
+            ],
+        },
+    ),
     # target q_proj, v_proj as modules, and down_proj as parameter
     (
         LoraConfig,
@@ -314,38 +330,75 @@ def test_targeting_module_and_targeting_param_equivalent(self):
         # LoRA outputs should be the same
         assert torch.allclose(out_lora_0, out_lora_1, atol=atol, rtol=rtol)

-    def test_target_multiple_parameters_on_same_module(self):
-        # for now, it is not supported to target multiple parameters from the same module with the same adapter,
-        # however, it is possible to target multiple parameters from same module with different adapters
+    def test_target_multiple_parameters_on_same_module(self, monkeypatch):
+        # test that if we target multiple nn.Parameters on the same module, all of them are being used during the
+        # forward pass
         torch.manual_seed(0)
-        model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+        model_id = "trl-internal-testing/tiny-Llama4ForCausalLM"
         with hub_online_once(model_id):
-            model = AutoModelForCausalLM.from_pretrained(model_id)
             x = torch.arange(10).view(2, 5)
-            with torch.inference_mode():
-                out_base = model(x, output_hidden_states=True).hidden_states[-1]
+            model = MyAutoModelForCausalLM.from_pretrained(model_id)
+            shape_gate_up_proj = model.model.layers[0].feed_forward.experts.gate_up_proj.shape
+            shape_down_proj = model.model.layers[0].feed_forward.experts.down_proj.shape
+            num_layers = len(model.model.layers)

-            # targeting gate_up_proj
-            config0 = LoraConfig(target_parameters=["feed_forward.experts.gate_up_proj"], init_lora_weights=False)
-            model = get_peft_model(model, config0)
-            with torch.inference_mode():
-                out_lora_0 = model(x, output_hidden_states=True).hidden_states[-1]
-            atol, rtol = 1e-6, 1e-6
-            assert not torch.allclose(out_base, out_lora_0, atol=atol, rtol=rtol)
+            target_parameters = ["feed_forward.experts.gate_up_proj", "feed_forward.experts.down_proj"]
+            num_params = len(target_parameters)
+            config = LoraConfig(target_parameters=target_parameters, init_lora_weights=False)
+            model = get_peft_model(model, config)

-            # targeting down_proj
-            config1 = LoraConfig(target_parameters=["feed_forward.experts.down_proj"], init_lora_weights=False)
-            model.add_adapter("other", config1)
-            model.set_adapter("other")
-            with torch.inference_mode():
-                out_lora_1 = model(x, output_hidden_states=True).hidden_states[-1]
-            assert not torch.allclose(out_base, out_lora_1, atol=atol, rtol=rtol)
-            assert not torch.allclose(out_lora_0, out_lora_1, atol=atol, rtol=rtol)
+            # CHECK FORWARD CALLS
+
+            # log the weights seen during the forward call
+            weights = []
+
+            def mock_forward(self, W):
+                weights.append(W)
+                return orig_forward(self, W)
+
+            from peft.tuners.lora.layer import _LoraParameterProxy
+
+            orig_forward = _LoraParameterProxy.forward
+            monkeypatch.setattr(_LoraParameterProxy, "forward", mock_forward)

-            # targeting both gate_up_proj and down_proj
-            model.base_model.set_adapter(["default", "other"])
+            num_steps = 3
             with torch.inference_mode():
-                out_lora_01 = model(x, output_hidden_states=True).hidden_states[-1]
-                assert not torch.allclose(out_base, out_lora_01, atol=atol, rtol=rtol)
-                assert not torch.allclose(out_lora_0, out_lora_01, atol=atol, rtol=rtol)
-                assert not torch.allclose(out_lora_1, out_lora_01, atol=atol, rtol=rtol)
+                for _ in range(num_steps):
+                    out_base = model(x, output_hidden_states=True).hidden_states[-1]
+
+            actual_call_count = len(weights)
+            # Note: We call forward twice per step, once to create the parametrization and once for the actual forward
+            # step. This may be a bit wasteful but it's not clear how to prevent this and overall is probably negligible
+            num_forward_per_step = 2
+            expected_call_count = num_steps * num_layers * num_params * num_forward_per_step
+            assert actual_call_count == expected_call_count
+
+            actual_shapes = {W.shape for W in weights}
+            expected_shapes = {shape_gate_up_proj, shape_down_proj}
+            assert actual_shapes == expected_shapes
+
+            # CHECK WEIGHT UPDATES
+
+            lora_weights_before = {
+                k: v.clone() for k, v in model.named_parameters() if "lora_A.default" in k or "lora_B.default" in k
+            }
+            print(lora_weights_before)
+            # sanity check:
+            assert len(lora_weights_before) == 2 * num_layers * num_params
+            # train
+            optim = torch.optim.SGD(model.parameters(), lr=0.01)
+            for _ in range(10):
+                optim.zero_grad()
+                out = model(x)
+                loss = out.logits.sum()
+                loss.backward()
+                optim.step()
+
+            print(lora_weights_before)
+            lora_weights_after = {
+                k: v for k, v in model.named_parameters() if "lora_A.default" in k or "lora_B.default" in k
+            }
+            assert lora_weights_before.keys() == lora_weights_after.keys()
+            atol, rtol = 0.1, 0.1
+            for key in lora_weights_before.keys():
+                assert not torch.allclose(lora_weights_before[key], lora_weights_after[key], atol=atol, rtol=rtol)
