
Commit a9b8ec8

Merge pull request #262 from aluu317/test_config_utils
Add config_utils tests
2 parents 537215f + dc32eda

File tree

1 file changed: +235 -0 lines

tests/utils/test_config_utils.py (+235)
@@ -0,0 +1,235 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
import base64
import os
import pickle

# Third Party
from peft import LoraConfig, PromptTuningConfig
import pytest

# First Party
from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH

# Local
from tuning.config import peft_config
from tuning.utils import config_utils

def test_get_hf_peft_config_returns_None_for_tuning_config_None():
    """Test that when tuning_config is None, the function returns None"""
    expected_config = None
    assert expected_config == config_utils.get_hf_peft_config("", None, "")


def test_get_hf_peft_config_returns_lora_config_correctly():
    """Test that tuning_config fields are passed to LoraConfig correctly;
    if a field is not defined, the default value is used
    """
    tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)

    config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
    assert isinstance(config, LoraConfig)
    assert config.task_type == "CAUSAL_LM"
    assert config.r == 3
    assert config.lora_alpha == 3
    assert (
        config.lora_dropout == 0.05
    )  # default value from local peft_config.LoraConfig
    assert config.target_modules == {
        "q_proj",
        "v_proj",
    }  # default value from local peft_config.LoraConfig
    assert config.init_lora_weights is True  # default value from HF peft.LoraConfig
    assert (
        config.megatron_core == "megatron.core"
    )  # default value from HF peft.LoraConfig


def test_get_hf_peft_config_ignores_tokenizer_path_for_lora_config():
    """Test that if a tokenizer path is given with a LoraConfig, it is ignored"""
    tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)

    config = config_utils.get_hf_peft_config(
        task_type="CAUSAL_LM",
        tuning_config=tuning_config,
        tokenizer_name_or_path="foo/bar/path",
    )
    assert isinstance(config, LoraConfig)
    assert config.task_type == "CAUSAL_LM"
    assert config.r == 3
    assert config.lora_alpha == 3
    assert not hasattr(config, "tokenizer_name_or_path")


def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear():
    """Test that when target_modules is ["all-linear"], it is converted to the str "all-linear" """
    tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])

    config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
    assert isinstance(config, LoraConfig)
    assert config.target_modules == "all-linear"


def test_get_hf_peft_config_returns_pt_config_correctly():
    """Test that the prompt tuning config is set properly for each field;
    when a value is not defined, the default value is used
    """
    tuning_config = peft_config.PromptTuningConfig(num_virtual_tokens=12)

    config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "foo/bar/path")
    assert isinstance(config, PromptTuningConfig)
    assert config.task_type == "CAUSAL_LM"
    assert (
        config.prompt_tuning_init == "TEXT"
    )  # default value from local peft_config.PromptTuningConfig
    assert config.num_virtual_tokens == 12
    assert (
        config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
    )  # default value from local peft_config.PromptTuningConfig
    assert config.tokenizer_name_or_path == "foo/bar/path"
    assert config.num_layers is None  # default value from HF peft.PromptTuningConfig
    assert (
        config.inference_mode is False
    )  # default value from HF peft.PromptTuningConfig


def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path():
    """Test that the tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
    Reference:
    https://github.com/huggingface/peft/blob/main/src/peft/tuners/prompt_tuning/config.py#L73
    """

    # When prompt_tuning_init is not TEXT, we can pass None for the tokenizer path
    tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="RANDOM")
    config = config_utils.get_hf_peft_config(
        task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
    )
    assert isinstance(config, PromptTuningConfig)
    assert config.tokenizer_name_or_path is None

    # When prompt_tuning_init is TEXT, an exception is raised if the tokenizer path is None
    tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="TEXT")
    with pytest.raises(ValueError) as err:
        config_utils.get_hf_peft_config(
            task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
        )
    assert "tokenizer_name_or_path can't be None" in str(err.value)
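Note: the tests above pin down how get_hf_peft_config maps the local dataclasses in tuning.config.peft_config onto the Hugging Face peft classes: LoRA configs drop the tokenizer path and collapse ["all-linear"] to the plain string "all-linear", while prompt-tuning configs keep the tokenizer path and let peft itself raise when prompt_tuning_init is TEXT and the path is None. The sketch below is a minimal reconstruction of that dispatch for orientation only; the helper name _sketch_get_hf_peft_config is hypothetical and the field handling is assumed, not copied from tuning/utils/config_utils.py.

# Hypothetical sketch of the dispatch exercised above; not the shipped
# implementation. LoraConfig, PromptTuningConfig and peft_config are the same
# imports used at the top of the test module.
from dataclasses import asdict


def _sketch_get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):
    if tuning_config is None:
        return None

    config_kwargs = asdict(tuning_config)  # assumes the local configs are dataclasses
    if isinstance(tuning_config, peft_config.LoraConfig):
        # ["all-linear"] becomes the plain string "all-linear"; LoraConfig has no
        # tokenizer_name_or_path field, so the path argument is simply dropped.
        if config_kwargs.get("target_modules") == ["all-linear"]:
            config_kwargs["target_modules"] = "all-linear"
        return LoraConfig(task_type=task_type, **config_kwargs)

    # Prompt tuning keeps the tokenizer path; peft raises a ValueError when
    # prompt_tuning_init == "TEXT" and tokenizer_name_or_path is None.
    config_kwargs.pop("tokenizer_name_or_path", None)
    return PromptTuningConfig(
        task_type=task_type,
        tokenizer_name_or_path=tokenizer_name_or_path,
        **config_kwargs,
    )
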
def test_create_tuning_config_for_peft_method_lora():
    """Test that LoraConfig is created for peft_method "lora"
    and fields are set properly.
    If unknown fields are passed, they are ignored
    """
    tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)
    assert isinstance(tune_config, peft_config.LoraConfig)
    assert tune_config.r == 234
    assert tune_config.lora_alpha == 32
    assert tune_config.lora_dropout == 0.05
    assert not hasattr(tune_config, "foo")


def test_create_tuning_config_for_peft_method_pt():
    """Test that PromptTuningConfig is created for peft_method "pt"
    and fields are set properly
    """
    tune_config = config_utils.create_tuning_config(
        "pt", foo="x", prompt_tuning_init="RANDOM"
    )
    assert isinstance(tune_config, peft_config.PromptTuningConfig)
    assert tune_config.prompt_tuning_init == "RANDOM"


def test_create_tuning_config_for_peft_method_none():
    """Test that no tuning config is created (None is returned) for peft_method "None" or None"""
    tune_config = config_utils.create_tuning_config("None")
    assert tune_config is None

    tune_config = config_utils.create_tuning_config(None)
    assert tune_config is None


def test_create_tuning_config_does_not_recognize_any_other_peft_method():
    """Test that a tuning config is only created for peft_method None, "None",
    "lora" or "pt"; any other value raises an AssertionError
    """
    with pytest.raises(AssertionError) as err:
        config_utils.create_tuning_config("hello", foo="x")
    assert "peft config hello not defined in peft.py" in str(err.value)
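Note: create_tuning_config is the string-to-dataclass factory. "lora" and "pt" build the local LoraConfig and PromptTuningConfig, None and "None" build nothing, any other value trips an assertion, and unknown keyword arguments are silently dropped. Below is a rough, hypothetical sketch of that contract (_sketch_create_tuning_config is not a real function); the real logic lives in tuning.utils.config_utils.create_tuning_config.

# Hypothetical sketch of the factory behavior tested above; peft_config is the
# same import used at the top of the test module.
import dataclasses


def _sketch_create_tuning_config(peft_method, **kwargs):
    assert peft_method in (None, "None", "lora", "pt"), (
        f"peft config {peft_method} not defined in peft.py"
    )
    if peft_method in (None, "None"):
        return None

    config_cls = (
        peft_config.LoraConfig if peft_method == "lora"
        else peft_config.PromptTuningConfig
    )
    # keep only kwargs that are real fields of the dataclass, so foo="x" is ignored
    field_names = {f.name for f in dataclasses.fields(config_cls)}
    return config_cls(**{k: v for k, v in kwargs.items() if k in field_names})
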
def test_update_config_can_handle_dot_for_nested_field():
    """Test that the function can read dotted field names in kwargs"""
    config = peft_config.LoraConfig(r=5)
    assert config.lora_alpha == 32  # default value is 32

    # update lora_alpha to 98
    kwargs = {"LoraConfig.lora_alpha": 98}
    config_utils.update_config(config, **kwargs)
    assert config.lora_alpha == 98


def test_update_config_does_nothing_for_unknown_field():
    """Test that the function does not change other config
    field values if a kwarg field is unknown
    """
    # foobar is an unknown field
    config = peft_config.LoraConfig(r=5)
    kwargs = {"LoraConfig.foobar": 98}
    config_utils.update_config(config, **kwargs)  # nothing happens
    assert config.r == 5  # did not change r value
    assert not hasattr(config, "foobar")


def test_update_config_can_handle_multiple_config_updates():
    """Test that the function can handle a tuple of configs"""
    config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
    kwargs = {"r": 98}
    config_utils.update_config(config, **kwargs)
    assert config[0].r == 98
    assert config[1].r == 98
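Note: update_config accepts either a single config or a tuple of configs and applies keyword updates in place. Plain keys update a field directly, dotted "ClassName.field" keys are scoped to the matching class, and fields the config does not define are ignored. The following is a hypothetical sketch of that contract, for orientation only; the shipped implementation in tuning.utils.config_utils may differ in detail.

# Hypothetical sketch of the update_config contract exercised above.
def _sketch_update_config(config, **kwargs):
    configs = config if isinstance(config, tuple) else (config,)
    for cfg in configs:
        for key, value in kwargs.items():
            if "." in key:
                cls_name, _, field = key.partition(".")
                # dotted keys only apply to the config whose class name matches
                if cls_name != type(cfg).__name__:
                    continue
            else:
                field = key
            # unknown fields are ignored instead of being added
            if hasattr(cfg, field):
                setattr(cfg, field, value)
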
def test_get_json_config_can_load_from_path():
    """Test that get_json_config can read the job config from the JSON file
    pointed to by the env var SFT_TRAINER_CONFIG_JSON_PATH
    """
    if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
        del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
    os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH

    job_config = config_utils.get_json_config()
    assert job_config is not None
    assert job_config["model_name_or_path"] == "bigscience/bloom-560m"


def test_get_json_config_can_load_from_envvar():
    """Test that get_json_config can read the job config from the
    base64-encoded env var SFT_TRAINER_CONFIG_JSON_ENV_VAR
    """
    config_json = {"model_name_or_path": "foobar"}
    message_bytes = pickle.dumps(config_json)
    base64_bytes = base64.b64encode(message_bytes)
    encoded_json = base64_bytes.decode("ascii")
    os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = encoded_json

    job_config = config_utils.get_json_config()
    assert job_config is not None
    assert job_config["model_name_or_path"] == "foobar"
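Note: the last two tests cover the two ways a job config reaches the trainer: a JSON file whose path is in SFT_TRAINER_CONFIG_JSON_PATH, or a base64-encoded pickled dict in SFT_TRAINER_CONFIG_JSON_ENV_VAR. The sketch below is one loader consistent with both tests; checking the encoded env var first is an assumption, and _sketch_get_json_config is not the real tuning.utils.config_utils.get_json_config. The whole module runs with pytest tests/utils/test_config_utils.py.

# Hypothetical loader consistent with the two tests above; assumptions: the
# encoded env var wins over the path, and its payload is a pickled dict.
import base64
import json
import os
import pickle


def _sketch_get_json_config():
    encoded = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")
    if encoded:
        return pickle.loads(base64.b64decode(encoded))

    json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
    if json_path:
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    return None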
