|
15 | 15 |
|
16 | 16 | import pytest
|
17 | 17 | import torch
|
| 18 | +from huggingface_hub import ModelCard |
18 | 19 | from transformers import AutoModelForCausalLM
|
19 | 20 |
|
20 |
| -from peft import AutoPeftModelForCausalLM, LoraConfig, PeftConfig, PeftModel, get_peft_model |
| 21 | +from peft import AutoPeftModelForCausalLM, BoneConfig, LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_model |
| 22 | + |
| 23 | +from .testing_common import hub_online_once |
21 | 24 |
|
22 | 25 |
|
23 | 26 | PEFT_MODELS_TO_TEST = [("peft-internal-testing/test-lora-subfolder", "test")]
|
@@ -112,3 +115,106 @@ def test_load_different_peft_and_base_model_revision(self, tmp_path):
|
112 | 115 |
|
113 | 116 | assert peft_model.peft_config["default"].base_model_name_or_path == base_model_id
|
114 | 117 | assert peft_model.peft_config["default"].revision == base_model_revision
|
| 118 | + |
| 119 | + |
class TestModelCard:
    """Tests for the model card (README.md) that PEFT writes on save_pretrained."""

    @pytest.mark.parametrize(
        "model_id, peft_config, tags, excluded_tags, pipeline_tag",
        [
            (
                "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
                LoraConfig(),
                ["transformers", "base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM", "lora"],
                [],
                None,
            ),
            (
                "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
                BoneConfig(),
                ["transformers", "base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM"],
                ["lora"],
                None,
            ),
            (
                "hf-internal-testing/tiny-random-BartForConditionalGeneration",
                LoraConfig(),
                [
                    "transformers",
                    "base_model:adapter:hf-internal-testing/tiny-random-BartForConditionalGeneration",
                    "lora",
                ],
                [],
                None,
            ),
            (
                "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
                LoraConfig(task_type=TaskType.CAUSAL_LM),
                ["transformers", "base_model:adapter:hf-internal-testing/tiny-random-Gemma3ForCausalLM", "lora"],
                [],
                "text-generation",
            ),
        ],
    )
    @pytest.mark.parametrize(
        "pre_tags",
        [
            ["tag1", "tag2"],
            [],
        ],
    )
    def test_model_card_has_expected_tags(
        self, model_id, peft_config, tags, excluded_tags, pipeline_tag, pre_tags, tmp_path
    ):
        """Make sure that PEFT sets the tags in the model card automatically and correctly.

        This is important so that a) the models are searchable on the Hub and b) some features depend on it to
        decide how to deal with them (e.g., inference).

        Makes sure that the base model tags are still present (if there are any).
        """
        with hub_online_once(model_id):
            base_model = AutoModelForCausalLM.from_pretrained(model_id)

        # Simulate a base model that already carries its own tags.
        if pre_tags:
            base_model.add_model_tags(pre_tags)

        peft_model = get_peft_model(base_model, peft_config)
        save_path = tmp_path / "adapter"

        peft_model.save_pretrained(save_path)

        model_card = ModelCard.load(save_path / "README.md")
        # All expected tags must be present ...
        assert set(tags).issubset(set(model_card.data.tags))

        # ... and tags that don't apply to this PEFT method must be absent.
        if excluded_tags:
            assert set(excluded_tags).isdisjoint(set(model_card.data.tags))

        # Pre-existing base model tags must survive saving the adapter.
        if pre_tags:
            assert set(pre_tags).issubset(set(model_card.data.tags))

        if pipeline_tag:
            assert model_card.data.pipeline_tag == pipeline_tag

    @pytest.fixture
    def custom_model_cls(self):
        """Return a minimal custom torch model class (not a transformers model)."""

        class MyNet(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(10, 20)
                self.l2 = torch.nn.Linear(20, 1)

            def forward(self, X):
                return self.l2(self.l1(X))

        return MyNet

    def test_custom_models_dont_have_transformers_tag(self, custom_model_cls, tmp_path):
        """Adapters on non-transformers models must not be tagged with "transformers"."""
        base_model = custom_model_cls()
        peft_config = LoraConfig(target_modules="all-linear")
        peft_model = get_peft_model(base_model, peft_config)

        peft_model.save_pretrained(tmp_path)

        model_card = ModelCard.load(tmp_path / "README.md")

        # A model card with tags is still written, but without the transformers tag.
        assert model_card.data.tags is not None
        assert "transformers" not in model_card.data.tags
0 commit comments