
Commit 7995572

Auto-tagging of PEFT models (#2599)
Features like inference need correctly set tags on the repo / the model card in order to be available. The Hub also uses tags to index models and make them searchable. With this change, PEFT tags models automatically: with `lora` if they happen to be trained with LoRA, with the base model, and with a custom `peft:method:<the method>` tag.

* Base model tags were never supported, they are now

  Before, PEFT simply ignored tags provided by the base model. Now the base model tags are added to the PEFT-specific model tags.

* Tag 'transformers' and add pipeline tag if possible

  We remove the `peft:method:*` tag because it needs more discussion and is partially unrelated to this change. It is replaced by the necessary `transformers` tag if the model is based on transformers. We also try to resolve the pipeline tag automatically if it isn't set. While there is the `transformers.pipelines.base.SUPPORTED_PEFT_TASKS` mapping, it is not sufficient to resolve the pipeline tag automatically since it is not a 1:1 mapping; only the causal LM case maps uniquely.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
1 parent 180777e commit 7995572
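
In practice, the new behavior looks roughly like the following minimal sketch, based on the tests added in this commit (the model ID is taken from those tests; the exact printed tag list is illustrative):

# Minimal sketch of the auto-tagging added by this commit; tag values shown
# in the comments are illustrative.
from transformers import AutoModelForCausalLM
from huggingface_hub import ModelCard
from peft import LoraConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))
peft_model.save_pretrained("adapter")  # writes the adapter plus a README.md model card

card = ModelCard.load("adapter/README.md")
print(card.data.tags)
# e.g. ['base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM',
#       'lora', 'transformers']
print(card.data.pipeline_tag)  # 'text-generation' for causal LM adapters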

File tree

3 files changed: +146 −1 lines changed

src/peft/peft_model.py

Lines changed: 34 additions & 0 deletions
@@ -1438,6 +1438,26 @@ def base_model_torch_dtype(self):
     def active_peft_config(self):
         return self.peft_config[self.active_adapter]
 
+    def _get_peft_specific_model_tags(self):
+        """Derive tags for the model card from the adapter's config. For example, setting the
+        base model is important for enabling support for HF inference providers but it also makes models more
+        searchable on the HF hub.
+        """
+        peft_method = self.active_peft_config.peft_type.value
+
+        tags = []
+
+        if hasattr(self.base_model, "model") and isinstance(self.base_model.model, transformers.PreTrainedModel):
+            tags.append("transformers")
+
+        if peft_method == "LORA":
+            tags.append("lora")
+
+        if hasattr(self.base_model, "name_or_path"):
+            tags.append(f"base_model:adapter:{self.base_model.name_or_path}")
+
+        return tags
+
     def create_or_update_model_card(self, output_dir: str):
         """
         Updates or create model card to include information about peft:
@@ -1453,6 +1473,20 @@ def create_or_update_model_card(self, output_dir: str):
 
         card.data["library_name"] = "peft"
 
+        tags = set()
+        base_model = self.get_base_model()
+        if hasattr(base_model, "model_tags"):
+            tags = tags.union(base_model.model_tags or [])
+
+        tags = tags.union(self._get_peft_specific_model_tags())
+        if tags:
+            card.data["tags"] = sorted(tags)
+
+        # One of the rare moments where we can select the pipeline tag with certainty, so let's do that.
+        # Makes it easier to deploy an adapter with auto inference since the user doesn't have to add any tags.
+        if not card.data.pipeline_tag and isinstance(self, PeftModelForCausalLM):
+            card.data.pipeline_tag = "text-generation"
+
         model_config = BaseTuner.get_model_config(self)
         model_config = None if model_config == DUMMY_MODEL_CONFIG else model_config
         if model_config is not None and "_name_or_path" in model_config:
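
To make the merge behavior above concrete: tags already set on the base model survive and are combined with the PEFT-specific ones via the set union, then sorted. A sketch, assuming a user-defined tag ("my-project" is made up here):

# Sketch of the tag-merging logic in create_or_update_model_card; the custom
# tag and resulting list are illustrative.
from transformers import AutoModelForCausalLM
from huggingface_hub import ModelCard
from peft import LoraConfig, get_peft_model

model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
base_model.add_model_tags(["my-project"])  # pre-existing tag on the base model

peft_model = get_peft_model(base_model, LoraConfig())
peft_model.save_pretrained("adapter")

card = ModelCard.load("adapter/README.md")
# card.data.tags is the sorted union of base model tags and PEFT-specific tags, e.g.:
# ['base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM',
#  'lora', 'my-project', 'transformers']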

tests/test_hub_features.py

Lines changed: 107 additions & 1 deletion
@@ -15,9 +15,12 @@
 
 import pytest
 import torch
+from huggingface_hub import ModelCard
 from transformers import AutoModelForCausalLM
 
-from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig, PeftModel, get_peft_model
+from peft import AutoPeftModelForCausalLM, BoneConfig, LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_model
+
+from .testing_common import hub_online_once
 
 
 PEFT_MODELS_TO_TEST = [("peft-internal-testing/test-lora-subfolder", "test")]
@@ -112,3 +115,106 @@ def test_load_different_peft_and_base_model_revision(self, tmp_path):
 
         assert peft_model.peft_config["default"].base_model_name_or_path == base_model_id
         assert peft_model.peft_config["default"].revision == base_model_revision
+
+
+class TestModelCard:
+    @pytest.mark.parametrize(
+        "model_id, peft_config, tags, excluded_tags, pipeline_tag",
+        [
+            (
+                "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
+                LoraConfig(),
+                ["transformers", "base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM", "lora"],
+                [],
+                None,
+            ),
+            (
+                "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
+                BoneConfig(),
+                ["transformers", "base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM"],
+                ["lora"],
+                None,
+            ),
+            (
+                "hf-internal-testing/tiny-random-BartForConditionalGeneration",
+                LoraConfig(),
+                [
+                    "transformers",
+                    "base_model:adapter:hf-internal-testing/tiny-random-BartForConditionalGeneration",
+                    "lora",
+                ],
+                [],
+                None,
+            ),
+            (
+                "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
+                LoraConfig(task_type=TaskType.CAUSAL_LM),
+                ["transformers", "base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM", "lora"],
+                [],
+                "text-generation",
+            ),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "pre_tags",
+        [
+            ["tag1", "tag2"],
+            [],
+        ],
+    )
+    def test_model_card_has_expected_tags(
+        self, model_id, peft_config, tags, excluded_tags, pipeline_tag, pre_tags, tmp_path
+    ):
+        """Make sure that PEFT sets the tags in the model card automatically and correctly.
+        This is important so that a) the models are searchable on the Hub and b) some features depend on it to
+        decide how to deal with them (e.g., inference).
+
+        Makes sure that the base model tags are still present (if there are any).
+        """
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        if pre_tags:
+            base_model.add_model_tags(pre_tags)
+
+        peft_model = get_peft_model(base_model, peft_config)
+        save_path = tmp_path / "adapter"
+
+        peft_model.save_pretrained(save_path)
+
+        model_card = ModelCard.load(save_path / "README.md")
+        assert set(tags).issubset(set(model_card.data.tags))
+
+        if excluded_tags:
+            assert set(excluded_tags).isdisjoint(set(model_card.data.tags))
+
+        if pre_tags:
+            assert set(pre_tags).issubset(set(model_card.data.tags))
+
+        if pipeline_tag:
+            assert model_card.data.pipeline_tag == pipeline_tag
+
+    @pytest.fixture
+    def custom_model_cls(self):
+        class MyNet(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l1 = torch.nn.Linear(10, 20)
+                self.l2 = torch.nn.Linear(20, 1)
+
+            def forward(self, X):
+                return self.l2(self.l1(X))
+
+        return MyNet
+
+    def test_custom_models_dont_have_transformers_tag(self, custom_model_cls, tmp_path):
+        base_model = custom_model_cls()
+        peft_config = LoraConfig(target_modules="all-linear")
+        peft_model = get_peft_model(base_model, peft_config)
+
+        peft_model.save_pretrained(tmp_path)
+
+        model_card = ModelCard.load(tmp_path / "README.md")
+
+        assert model_card.data.tags is not None
+        assert "transformers" not in model_card.data.tags

tests/testing_common.py

Lines changed: 5 additions & 0 deletions
@@ -266,6 +266,11 @@ def check_modelcard(self, tmp_dirname, model):
         else:  # a custom model
             assert "base_model" not in dct
 
+        # The Hub expects the lora tag to be set for PEFT LoRA models since they
+        # have explicit support for things like inference.
+        if model.active_peft_config.peft_type.value == "LORA":
+            assert "lora" in dct["tags"]
+
     def check_config_json(self, tmp_dirname, model):
         # check the generated config.json
         filename = os.path.join(tmp_dirname, "adapter_config.json")
