Commit 53c25fe

Release: 0.17.1 changes (#2739)
* FIX Multiple issues with target_parameters (#2710)
* Bump version to 0.17.1
1 parent 48f6493 commit 53c25fe

File tree: 10 files changed, +347 −109 lines

docs/source/developer_guides/lora.md

Lines changed: 4 additions & 1 deletion

@@ -276,7 +276,10 @@ The same logic applies to `alpha_pattern`. If you're in doubt, don't try to get
 
 Generally, you should use `target_modules` to target the module (e.g. `nn.Linear`). However, in some circumstances, this is not possible. E.g., in many mixture of expert (MoE) layers in HF Transformers, instead of using `nn.Linear`, an `nn.Parameter` is used. PEFT normally overwrites the `forward` method for LoRA, but for `nn.Parameter`, there is none. Therefore, to apply LoRA to that parameter, it needs to be targeted with `target_parameters`. As an example, for [Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164), you can pass: `target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj']`.
 
-At the moment, this argument allows to target 2-dim or 3-dim `nn.Parameter`s. It is assumed that in the case of a 3-dim parameter, the 0th dimension is the expert dimension.
+#### Caveats
+
+- At the moment, this argument allows to target 2-dim or 3-dim `nn.Parameter`s. It is assumed that in the case of a 3-dim parameter, the 0th dimension is the expert dimension.
+- It is currently not possible to add multiple LoRA adapters (via `model.add_adapter` or `model.load_adapter`) that use `target_parameters` at the same time.
 
 ## Optimizers
 
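
For context, a minimal configuration sketch of what the caveats above refer to. The parameter paths are the Llama4 expert weights named in the docs; the checkpoint string is a placeholder and the use of AutoModelForCausalLM is an assumption, not part of this commit.

    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # MoE experts expose their weights as nn.Parameter, so they are targeted via
    # `target_parameters` rather than `target_modules`.
    base_model = AutoModelForCausalLM.from_pretrained("<llama4-checkpoint>")  # placeholder model id
    config = LoraConfig(
        target_parameters=["feed_forward.experts.gate_up_proj", "feed_forward.experts.down_proj"],
    )
    peft_model = get_peft_model(base_model, config)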

setup.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@
 from setuptools import find_packages, setup
 
 
-VERSION = "0.17.0"
+VERSION = "0.17.1"
 
 extras = {}
 extras["quality"] = [

src/peft/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.17.0"
+__version__ = "0.17.1"
 
 from .auto import (
     MODEL_TYPE_TO_PEFT_MODEL_MAPPING,

src/peft/tuners/lora/layer.py

Lines changed: 24 additions & 6 deletions

@@ -2015,14 +2015,28 @@ def _remove_parametrizations(self):
                 "Something went wrong, please report this issue on PEFT: https://github.com/huggingface/peft/issues"
             )
 
-        if len(base_layer.parametrizations[parameter_name]) == 1:
+        param_list = base_layer.parametrizations[parameter_name]
+        if len(param_list) == 1:
             # last parametrization, we can safely remove it completely
             nn.utils.parametrize.remove_parametrizations(base_layer, parameter_name, leave_parametrized=False)
-        else:
-            # TODO: If there are multiple parametrizations for the same parameter_name, we currently remove all of them,
-            # which is not desired. Unfortunately, PyTorch does not support this directly, so we need to take care.
-            # For now, remove all parametrizations.
-            nn.utils.parametrize.remove_parametrizations(base_layer, parameter_name, leave_parametrized=False)
+            return
+
+        # If there are multiple parametrizations for the same parameter_name, we only want to remove the LoRA proxy.
+        # Unfortunately, PyTorch does not support this directly, so we need to take care of it manually. To achieve
+        # this, we check the ParameterList from the back until we find the _LoraParameterProxy instance and then remove
+        # it.
+        reversed_indices = reversed(range(len(param_list)))
+        for i in reversed_indices:
+            module = param_list[i]
+            if isinstance(module, _LoraParameterProxy):
+                del param_list[i]
+                break
+        else:  # no break encountered
+            # this should not happen, but raising an error is probably not necessary
+            warnings.warn(
+                f"Could not find any LoRA parametrization on {self}, please open an issue on "
+                "https://github.com/huggingface/peft/issues and report this warning."
+            )
 
     def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
         # same as lora.Linear.merge but not hard-coding base_layer.weight and without special cases like variants removed

@@ -2106,6 +2120,10 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
 
     def __repr__(self) -> str:
         rep = super().__repr__()
+        idx = rep.find("(") + 1
+        # insert the name of the parameter to allow the repr to be disambiguous when multiple parameters on the same
+        # module are being targeted
+        rep = f"{rep[:idx]}\n parameter_name='{self.parameter_name}',{rep[idx:]}"
         return "lora." + rep
 
 
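
The fix above leans on PyTorch's parametrization API: `remove_parametrizations` always strips every parametrization registered under a name, so to drop only the LoRA proxy the patch deletes the matching entry from the parametrization list directly. A standalone sketch of that mechanism (plain PyTorch, not PEFT code; the two toy parametrizations are made up for illustration):

    import torch
    from torch import nn
    from torch.nn.utils import parametrize

    class AddOne(nn.Module):
        def forward(self, W):
            return W + 1.0

    class TimesTwo(nn.Module):
        def forward(self, W):
            return W * 2.0

    lin = nn.Linear(2, 2)
    parametrize.register_parametrization(lin, "weight", AddOne())
    parametrize.register_parametrization(lin, "weight", TimesTwo())

    param_list = lin.parametrizations["weight"]
    print(len(param_list))  # 2

    # remove_parametrizations(lin, "weight") would drop *both* entries; deleting
    # from the list (as the patch does) removes only the chosen parametrization.
    del param_list[-1]
    print(len(lin.parametrizations["weight"]))  # 1
    print(lin.weight)  # only AddOne is applied to the original weight now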

src/peft/tuners/lora/model.py

Lines changed: 12 additions & 0 deletions

@@ -185,6 +185,18 @@ def _create_and_replace(
         if current_key is None:
             raise ValueError("Current Key shouldn't be `None`")
 
+        if lora_config.target_parameters:
+            # Right now, unfortunately, we don't support multiple adapters with target_parameters on the same model.
+            other_configs_use_target_params = any(
+                conf.target_parameters for key, conf in self.peft_config.items() if key != adapter_name
+            )
+            if other_configs_use_target_params:
+                raise ValueError(
+                    f"Adding a LoRA config with `target_parameters={lora_config.target_parameters}` but there are "
+                    "already other LoRA adapters on this model that use `target_parameters`. At the moment, only "
+                    "one LoRA adapter per model with `target_parameters` is allowed."
+                )
+
         # Regexp matching - Find key which matches current target_name in patterns provided
         r_key = get_pattern_key(lora_config.rank_pattern.keys(), current_key)
         alpha_key = get_pattern_key(lora_config.alpha_pattern.keys(), current_key)
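
To illustrate when the new check fires, a usage sketch mirroring the test added in tests/test_initialization.py further down (the toy model and the adapter name "other" are illustrative):

    from torch import nn
    from peft import LoraConfig, get_peft_model

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 10)

    config = LoraConfig(target_modules=[], target_parameters=["linear.weight"])
    model = get_peft_model(MyModule(), config)

    # A second adapter that also uses `target_parameters` now hits the check above:
    # ValueError: ... only one LoRA adapter per model with `target_parameters` is allowed.
    model.add_adapter(adapter_name="other", peft_config=config)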

src/peft/tuners/tuners_utils.py

Lines changed: 68 additions & 34 deletions

@@ -722,43 +722,77 @@ def inject_adapter(
     def _inject_parameters(
         self, peft_config: PeftConfig, model: nn.Module, adapter_name: str, low_cpu_mem_usage: bool
     ) -> None:
-        # TODO very simple matching, might not cover all use cases
-        target_names = set(peft_config.target_parameters)
-        for module_name, module in model.named_modules():
-            for param_name, param in module.named_parameters(recurse=False):
-                # It is possible that the layer is already a PEFT layer and needs updating with a new adapter. In this
-                # case, the name of parameter would be something like `model.layers.0.experts.base_layer.weight`, i.e.
-                # there is a "base_layer" inserted in the name. We need to remove that, otherwise we won't be able to
-                # match correctly (in this case, "experts.weight" would not match).
-                prefix, _, suffix = module_name.rpartition(".base_layer")
+        """Inject layers based on peft_config.target_modules"""
+
+        def strip_base_layer_from_name(module_name):
+            # It is possible that the layer is already a PEFT layer and needs updating with a new adapter. In this case,
+            # the name of parameter would be something like `model.layers.0.experts.base_layer.weight`, i.e. there is a
+            # "base_layer" inserted in the name. We need to remove that, otherwise we won't be able to match correctly
+            # (in this case, "experts.weight" would not match).
+            name = ".base_layer"
+            while name in module_name:
+                prefix, _, suffix = module_name.rpartition(name)
                 module_name = prefix + suffix
-                key = f"{module_name}.{param_name}"
-                # we're interested in finding the "lowest" module that contains the parameter, hence recurse=False
-                if (key in target_names) or any(key.endswith(f".{target_key}") for target_key in target_names):
-                    self.targeted_parameter_names.append(key)
+            return module_name
+
+        def create_and_replace_param(module_name, key, param_name):
+            # helper function to avoid duplication
+            parent, target, target_name = _get_submodules(model, module_name)
+            unwrapped_module_name = strip_base_layer_from_name(module_name)
+            unwrapped_module = model.get_submodule(unwrapped_module_name)
+            # use the class name for checking to avoid circular import
+            if isinstance(unwrapped_module, BaseTunerLayer) and unwrapped_module.__class__.__name__ != "ParamWrapper":
+                raise ValueError(
+                    f"Trying to wrap an `nn.Parameter` of layer '{unwrapped_module_name}' of type "
+                    f"{type(target).__name__}, which is not a valid target. Make sure that this layer is not "
+                    "also targeted with `target_modules`. For some models, PEFT will do this automatically, "
+                    "try setting `target_modules=[]` to prevent it."
+                )
 
-                    parent, target, target_name = _get_submodules(model, module_name)
-                    # use the class name for checking to avoid circular import
-                    if isinstance(target, BaseTunerLayer) and target.__class__.__name__ != "ParamWrapper":
-                        raise ValueError(
-                            f"Trying to wrap an `nn.Parameter` of layer '{target_name}' of type "
-                            f"{type(target).__name__}, which is not a valid target. Make sure that this layer is not "
-                            "also targeted with `target_modules`. For some models, PEFT will do this automatically, "
-                            "try setting `target_modules=[]` to prevent it."
-                        )
+            self._check_target_module_compatiblity(peft_config, model, target_name)
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                self._create_and_replace(
+                    peft_config,
+                    adapter_name,
+                    target,
+                    target_name,
+                    parent,
+                    current_key=key,
+                    parameter_name=param_name.rpartition(".")[-1],
+                )
 
-                    self._check_target_module_compatiblity(peft_config, model, target_name)
-                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                    with ctx():
-                        self._create_and_replace(
-                            peft_config,
-                            adapter_name,
-                            target,
-                            target_name,
-                            parent,
-                            current_key=key,
-                            parameter_name=param_name.rpartition(".")[-1],
-                        )
+        # TODO very simple matching, might not cover all use cases
+        unsorted_target_names = set(peft_config.target_parameters)
+        # As the order of matching can influence the nesting of multiple params on the same module, ensure determinism
+        # by sorting.
+        target_names = sorted(unsorted_target_names)
+        for module_name, module in model.named_modules():
+            if hasattr(module, "parametrizations"):
+                # Deal with the case that the parameter is already parametrized. The issue is that we would not be able
+                # to match `f"{module_name}.{param_name}"`, as the parameter is now something like
+                # `module.parametrization.weight`.
+                for key in target_names:
+                    target_module_name, _, param_name = key.rpartition(".")
+                    if target_module_name != module_name:
+                        continue
+                    if getattr(module, param_name, None) is None:
+                        continue
+                    create_and_replace_param(module_name, key, param_name)
+                    self.targeted_parameter_names.append(key)
+            else:
+                # Standard case: the parameter is not already parametrized. Note, however, that the model could already
+                # be nested with lora.ParamWrapper, as this is how we allow targeting multiple Parameters on the same
+                # module.
+                unwrapped_module_name = strip_base_layer_from_name(module_name)
+                # we're interested in finding the "lowest" module that contains the parameter, hence recurse=False
+                for param_name, param in module.named_parameters(recurse=False):
+                    key = f"{unwrapped_module_name}.{param_name}"
+                    if (key in target_names) or any(key.endswith(f".{target_key}") for target_key in target_names):
+                        # Note: We use the unwrapped_module_name to check if the key matches, but we use the module_name for
+                        # replacement, since we want to replace the wrapped module.
+                        create_and_replace_param(module_name, key, param_name)
+                        self.targeted_parameter_names.append(key)
 
     def merge_adapter(self, adapter_names: Optional[list[str]] = None, safe_merge: bool = False) -> None:
         """

tests/test_custom_models.py

Lines changed: 30 additions & 0 deletions

@@ -936,6 +936,11 @@
 }
 
 
+def _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs):
+    if (config_cls == LoraConfig) and config_kwargs.get("target_parameters"):
+        pytest.skip("LoRA with multiple adapters with target_parameters is not supported")
+
+
 class MLP(nn.Module):
     def __init__(self, bias=True):
         super().__init__()

@@ -1389,6 +1394,7 @@ def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kw
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_load_model_low_cpu_mem_usage(self, test_name, model_id, config_cls, config_kwargs):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
         self._test_load_model_low_cpu_mem_usage(model_id, config_cls, config_kwargs)
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)

@@ -1397,6 +1403,7 @@ def test_from_pretrained_config_construction(self, test_name, model_id, config_c
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kwargs):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
         self._test_load_multiple_adapters(model_id, config_cls, config_kwargs)
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)

@@ -1995,6 +2002,8 @@ def run_with_disable(config_kwargs, bias):
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_active_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
+
         model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
         config = config_cls(
             base_model_name_or_path=model_id,

@@ -2085,10 +2094,12 @@ def test_disable_adapters_exiting_context_irregular_state(self, test_name, model
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
         self._test_delete_adapter(model_id, config_cls, config_kwargs)
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
         self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)
 
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)

@@ -2786,6 +2797,19 @@ def test_repr_lora_conv2d(self):
         assert "lora_B" in print_output
         assert "default" in print_output
 
+    def test_repr_lora_paramwrapper(self):
+        config = LoraConfig(target_parameters=["lin0.weight"])
+        model = get_peft_model(MLP(), config)
+        print_output = repr(model.model.lin0)
+        assert print_output.startswith("lora.ParamWrapper")
+        # important: targeted parameter should be contained:
+        assert "parameter_name='weight'" in print_output
+        assert "in_features=10" in print_output
+        assert "out_features=20" in print_output
+        assert "lora_A" in print_output
+        assert "lora_B" in print_output
+        assert "default" in print_output
+
 
 class TestMultipleActiveAdapters:
     """

@@ -2820,6 +2844,8 @@ def resolve_model_cls(self, tuner_method):
     def test_multiple_active_adapters_forward(
         self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2
     ):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs_2)
+
         torch.manual_seed(0)
 
         model = self.resolve_model_cls(tuner_method)

@@ -2878,6 +2904,8 @@ def test_multiple_active_adapters_forward(
     def test_multiple_active_adapters_merge_and_unmerge(
         self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2
     ):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs_2)
+
         torch.manual_seed(0)
 
         model = self.resolve_model_cls(tuner_method)

@@ -2911,6 +2939,8 @@ def test_multiple_active_adapters_merge_and_unmerge(
         "test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2", MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES
     )
     def test_merge_layers_multi(self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2):
+        _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs_2)
+
         torch.manual_seed(0)
 
         model = self.resolve_model_cls(tuner_method)

tests/test_initialization.py

Lines changed: 21 additions & 1 deletion

@@ -1406,7 +1406,7 @@ def __init__(self):
                 self.linear = nn.Linear(10, 10)
 
         base_model = MyModule()
-        config = LoraConfig(target_modules=["linear"], target_parameters=["weight"])
+        config = LoraConfig(target_modules=["linear"], target_parameters=["linear.weight"])
         msg = "Trying to wrap an `nn.Parameter` of layer 'linear' of type Linear, which is not a valid target."
         with pytest.raises(ValueError, match=msg):
             get_peft_model(base_model, config)

@@ -1445,6 +1445,26 @@ def test_valid_target_modules_invalid_target_parameters_warns(self):
         with pytest.warns(RuntimeWarning, match=msg):
             get_peft_model(model, config)
 
+    def test_adding_multiple_adapters_with_target_parameters_raises(self):
+        model = self.get_model()
+        config = LoraConfig(target_modules=[], target_parameters=["linear.weight"])
+        model = get_peft_model(model, config)
+        msg = re.escape("only one LoRA adapter per model with `target_parameters` is allowed")
+        with pytest.raises(ValueError, match=msg):
+            model.add_adapter(adapter_name="other", peft_config=config)
+
+    def test_loading_loading_adapters_with_target_parameters_raises(self, tmp_path):
+        model = self.get_model()
+        config = LoraConfig(target_modules=[], target_parameters=["linear.weight"])
+        model = get_peft_model(model, config)
+        model.save_pretrained(tmp_path)
+
+        model = self.get_model()
+        model = PeftModel.from_pretrained(model, tmp_path)
+        msg = re.escape("only one LoRA adapter per model with `target_parameters` is allowed")
+        with pytest.raises(ValueError, match=msg):
+            model.load_adapter(tmp_path, adapter_name="other")
+
 
 class TestLokrInitialization:
     torch_device = infer_device()
