
Commit 337be05

ENH: Adapter injection based on state_dict (#2637)
Make it possible to inject the PEFT adapters based on a state_dict instead of the PEFT config. See huggingface/diffusers#11874 for context.

Description

Right now, when creating a PEFT adapter like LoRA, the adapter layers are injected based on the PEFT config, most notably the entries in `target_modules`, though other arguments also play into this. Generally, this is a good approach, but it breaks down in some situations. For instance, in diffusers, we often have the situation that the checkpoint was created without PEFT/diffusers, so there is no PEFT config, only the `state_dict`. To load these checkpoints in diffusers, the current approach is to reverse-engineer a valid PEFT config based on the keys in the `state_dict`. Unfortunately, this is error-prone. Moreover, not every combination of `state_dict` keys can be easily expressed in a PEFT config through a combination of `target_modules`, `exclude_modules`, etc. In theory, everything can be expressed by passing `target_modules=<regex_pattern>`, but reverse-engineering such a regex correctly and efficiently is very hard (and thus currently not done).

This PR implements a completely different approach to injecting adapters. Instead of relying on the PEFT config to determine which layers to target, it takes the `state_dict` directly as the source of truth, which should allow matching exactly what is desired (a minimal usage sketch is included below).

Implementation details

I took care to implement this change in a way that, if no `state_dict` is passed, the exact same code path as before is taken. The risk of breaking anything should thus be minimal.

Technically, it is not necessary to pass the full `state_dict`; we are only interested in its keys. I still called the argument `state_dict`, since that is typically what we have at this point, but this can easily be changed.

I thought it might be a good idea, when the `state_dict` is used, to still check which modules would have been targeted if we had used the PEFT config. The two results are then compared, and a warning is given if they differ. This allows the user to see if the PEFT config is not correctly specified. While running some diffusers tests, I never encountered this warning, which is good. However, if we plan, for instance, to get rid of all the reverse engineering of the PEFT config in diffusers, it would make more sense not to give this warning.

Caveats

When the original LoRA model was using `target_parameters`, injecting from a `state_dict` will not work correctly. The problem is that the `state_dict` looks the same whether a module or a parameter was targeted, so we cannot correctly determine the user's intent. For now, this is what I decided to do:

1. Always assume that `target_modules` is meant, as it is by far the more common case.
2. When we detect `target_parameters` while using the `state_dict` for injection, raise an error.
3. If we don't detect this, injection might just slip through, resulting in modules being targeted (if they are valid modules) instead of parameters.
4. Document that these two features don't work together.

Overall, I think this is not too concerning, as both features are rather niche and thus unlikely to be used in conjunction.

Related changes

While working on this PR, I made a couple of related, though not strictly necessary, changes:

- Refactor the tests in `test_low_level_api.py` to use pytest instead of unittest.
- Add default target modules for LoHa and LoKr (just copying LoRA's).
- Most PEFT methods' model classes, like `LoraModel`, had an `__init__` that effectively just called `super()` with the same arguments. These `__init__` methods were removed.
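For illustration, here is a minimal sketch of the new flow, based on the documentation added in this commit. The model id and checkpoint path are placeholders, and the final `set_peft_model_state_dict` call assumes the checkpoint keys line up with the injected layers; treat this as a sketch rather than canonical usage.

```python
# Sketch of state_dict-based injection (placeholder model id and file path).
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM

from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

base_model = AutoModelForCausalLM.from_pretrained("your/base-model")  # placeholder
state_dict = load_file("adapter_model.safetensors")                   # placeholder

# The state_dict keys, not target_modules, decide which layers get adapters injected.
model = inject_adapter_in_model(LoraConfig(), base_model, state_dict=state_dict)

# Injection only creates (uninitialized) adapter layers; the weights are loaded separately.
set_peft_model_state_dict(model, state_dict)
```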
1 parent bb4fb50 commit 337be05

File tree

26 files changed: +497 additions, −84 deletions


docs/source/developer_guides/low_level_api.md

Lines changed: 23 additions & 1 deletion
@@ -16,7 +16,7 @@ rendered properly in your Markdown viewer.
 
 # Adapter injection
 
-With PEFT, you can inject trainable adapters into any `torch` module which allows you to use adapter methods without relying on the modeling classes in PEFT. Currently, PEFT supports injecting [LoRA](../conceptual_guides/adapter#low-rank-adaptation-lora), [AdaLoRA](../conceptual_guides/adapter#adaptive-low-rank-adaptation-adalora), and [IA3](../conceptual_guides/ia3) into models because for these adapters, inplace modification of the model is sufficient for finetuning it.
+With PEFT, you can inject trainable adapters into any `torch` module which allows you to use adapter methods without relying on the modeling classes in PEFT. This works for all adapters except for those based on prompt learning (e.g. prefix tuning or p-tuning).
 
 Check the table below to see when you should inject adapters.
 
@@ -87,6 +87,28 @@ DummyModel(
 )
 ```
 
+### Injection based on a `state_dict`
+
+Sometimes, it is possible that there is a PEFT adapter checkpoint but the corresponding PEFT config is not known for whatever reason. To inject the PEFT layers for this checkpoint, you would usually have to reverse-engineer the corresponding PEFT config, most notably the `target_modules` argument, based on the `state_dict` from the checkpoint. This can be cumbersome and error prone. To avoid this, it is also possible to call [`inject_adapter_in_model`] and pass the loaded `state_dict` as an argument:
+
+```python
+from safetensors.torch import load_file
+
+model = ...
+state_dict = load_file(<path-to-safetensors-file>)
+lora_config = LoraConfig(...)
+model = inject_adapter_in_model(lora_config, model, state_dict=state_dict)
+```
+
+In this case, PEFT will use the `state_dict` as reference for which layers to target instead of using the PEFT config. As a user, you don't have to set the exact `target_modules` of the PEFT config for this to work. However, you should still pass a PEFT config of the right type, in this example `LoraConfig`; you can leave the `target_modules` as `None`.
+
+Be aware that this still only creates the uninitialized PEFT layers; the values from the `state_dict` are not used to populate the model weights. To populate the weights, proceed with calling [`set_peft_model_state_dict`] as described below.
+
+⚠️ Note that if there is a mismatch between what is configured in the PEFT config and what is found in the `state_dict`, PEFT will warn you about this. You can ignore the warning if you know that the PEFT config is not correctly specified.
+
+> [!WARNING]
+> If the original PEFT adapter was using `target_parameters` instead of `target_modules`, injecting from a `state_dict` will not work correctly. In this case, it is mandatory to use the correct PEFT config for injection.
+
 ## Saving the model
 
 To only save the adapter, use the [`get_peft_model_state_dict`] function:
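As a complement to the documentation above, the following sketch (continuing the variables from the example in the diff) shows one way to surface the mismatch warning during development. It assumes the warning is emitted through Python's standard `warnings` machinery.

```python
import warnings

# Record warnings emitted during injection to check whether the PEFT config
# and the state_dict disagree about which layers to target.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = inject_adapter_in_model(lora_config, model, state_dict=state_dict)

for warning in caught:
    print(warning.category.__name__, warning.message)
```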

src/peft/mapping.py

Lines changed: 14 additions & 3 deletions
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -45,7 +45,11 @@ def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
 
 
 def inject_adapter_in_model(
-    peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default", low_cpu_mem_usage: bool = False
+    peft_config: PeftConfig,
+    model: torch.nn.Module,
+    adapter_name: str = "default",
+    low_cpu_mem_usage: bool = False,
+    state_dict: Optional[dict[str, torch.Tensor]] = None,
 ) -> torch.nn.Module:
     r"""
     A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
@@ -61,6 +65,11 @@ def inject_adapter_in_model(
             The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
         low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
             Create empty adapter weights on meta device. Useful to speed up the loading process.
+        state_dict (`dict`, *optional*, defaults to `None`)
+            If a state_dict is passed here, the adapters will be injected based on the entries of the state_dict. This
+            can be useful when the exact `target_modules` of the PEFT method is unknown, for instance because the
+            checkpoint was created without meta data. Note that the values from the state_dict are not used, only the
+            keys are used to determine the correct layers that should be adapted.
     """
     if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
         raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")
@@ -73,6 +82,8 @@ def inject_adapter_in_model(
     tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
 
     # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
-    peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
+    peft_model = tuner_cls(
+        model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage, state_dict=state_dict
+    )
 
     return peft_model.model
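Since the docstring states that only the keys of the `state_dict` matter, here is a self-contained sketch of that contract on a toy model: it injects once based on the config merely to obtain realistic adapter keys, then re-injects into a fresh model from the keys alone. It assumes the injection code accepts keys in the format that `get_peft_model_state_dict` produces.

```python
import torch
from torch import nn

from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)


# Inject once based on the config, only to obtain a realistic set of adapter keys.
reference = inject_adapter_in_model(LoraConfig(target_modules=["linear"]), DummyModel())
adapter_keys = get_peft_model_state_dict(reference).keys()

# Re-inject into a fresh model from the keys alone; the tensor values are never read.
keys_only = {key: torch.empty(0) for key in adapter_keys}
model = inject_adapter_in_model(LoraConfig(), DummyModel(), state_dict=keys_only)
```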

src/peft/tuners/adalora/model.py

Lines changed: 2 additions & 2 deletions
@@ -65,8 +65,8 @@ class AdaLoraModel(LoraModel):
 
     # Note: don't redefine prefix here, it should be inherited from LoraModel
 
-    def __init__(self, model, config, adapter_name):
-        super().__init__(model, config, adapter_name)
+    def __init__(self, model, config, adapter_name, **kwargs):
+        super().__init__(model, config, adapter_name, **kwargs)
 
         traininable_mode_counter = 0
         for config in self.peft_config.values():

src/peft/tuners/boft/model.py

Lines changed: 0 additions & 3 deletions
@@ -74,9 +74,6 @@ class BOFTModel(BaseTuner):
 
     prefix: str = "boft_"
 
-    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
-        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
-
     def _check_new_adapter_config(self, config: BOFTConfig) -> None:
         """
         A helper method to check the config when a new adapter is being added.

src/peft/tuners/c3a/model.py

Lines changed: 0 additions & 3 deletions
@@ -55,9 +55,6 @@ class C3AModel(BaseTuner):
 
     prefix: str = "c3a_"
 
-    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
-        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
-
     def _check_new_adapter_config(self, config: C3AConfig) -> None:
         """
         A helper method to check the config when a new adapter is being added.

src/peft/tuners/fourierft/model.py

Lines changed: 0 additions & 3 deletions
@@ -58,9 +58,6 @@ class FourierFTModel(BaseTuner):
 
     prefix: str = "fourierft_"
 
-    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
-        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
-
     def _check_new_adapter_config(self, config: FourierFTConfig) -> None:
         """
         A helper method to check the config when a new adapter is being added.

src/peft/tuners/ia3/model.py

Lines changed: 0 additions & 3 deletions
@@ -75,9 +75,6 @@ class IA3Model(BaseTuner):
 
     prefix: str = "ia3_"
 
-    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False):
-        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
-
     @staticmethod
     def _create_new_module(ia3_config, adapter_name, target, **kwargs):
         # avoid eager bnb import

src/peft/tuners/ln_tuning/model.py

Lines changed: 0 additions & 4 deletions
@@ -65,10 +65,6 @@ class LNTuningModel(BaseTuner):
 
     prefix: str = "ln_tuning_"
 
-    def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
-        # self.adapter_name = adapter_name
-        super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
-
     def __getattr__(self, name: str):
         """Forward missing attributes to the wrapped module."""
         try:

src/peft/tuners/loha/model.py

Lines changed: 11 additions & 0 deletions
@@ -18,6 +18,7 @@
 from torch import nn
 
 from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
+from peft.utils import TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING
 from peft.utils.other import get_pattern_key
 
 from .layer import Conv2d, Linear, LoHaLayer
@@ -110,3 +111,13 @@ def _create_and_replace(
         else:
             new_module = self._create_new_module(config, adapter_name, target, **kwargs)
         self._replace_module(parent, target_name, new_module, target)
+
+    @staticmethod
+    def _prepare_adapter_config(peft_config, model_config):
+        if peft_config.target_modules is None:
+            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING:
+                raise ValueError("Please specify `target_modules` in `peft_config`")
+            peft_config.target_modules = set(
+                TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING[model_config["model_type"]]
+            )
+        return peft_config

src/peft/tuners/lokr/model.py

Lines changed: 11 additions & 0 deletions
@@ -18,6 +18,7 @@
 from torch import nn
 
 from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
+from peft.utils import TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING
 from peft.utils.other import get_pattern_key
 
 from .layer import Conv2d, Linear, LoKrLayer
@@ -112,3 +113,13 @@ def _create_and_replace(
         else:
             new_module = self._create_new_module(config, adapter_name, target, **kwargs)
         self._replace_module(parent, target_name, new_module, target)
+
+    @staticmethod
+    def _prepare_adapter_config(peft_config, model_config):
+        if peft_config.target_modules is None:
+            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING:
+                raise ValueError("Please specify `target_modules` in `peft_config`")
+            peft_config.target_modules = set(
+                TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING[model_config["model_type"]]
+            )
+        return peft_config
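To illustrate the new LoHa/LoKr defaults added here and in the LoHa diff above: with `_prepare_adapter_config` in place, `target_modules` can now be omitted for architectures covered by the default mapping (copied from LoRA). A minimal sketch; the model id below is a placeholder and assumes an architecture that the mapping covers.

```python
from transformers import AutoModelForCausalLM

from peft import LoHaConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("your/causal-lm")  # placeholder

# No target_modules: the default mapping for this model_type is used instead,
# mirroring what LoRA already did; a LoKrConfig would work the same way.
config = LoHaConfig()
peft_model = get_peft_model(base_model, config)
```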
