Commit 906ce02

Separate tests out and use docstrings
Signed-off-by: Angel Luu <angel.luu@us.ibm.com>
1 parent 7054070 commit 906ce02

File tree

1 file changed: +63 -33 lines changed

tests/utils/test_config_utils.py

Lines changed: 63 additions & 33 deletions
@@ -32,14 +32,16 @@
 from tuning.utils import config_utils
 
 
-def test_get_hf_peft_config_returns_None_for_FT():
+def test_get_hf_peft_config_returns_None_for_tuning_config_None():
+    """Test that when tuning_config is None, the function returns None"""
     expected_config = None
     assert expected_config == config_utils.get_hf_peft_config("", None, "")
 
 
-def test_get_hf_peft_config_returns_Lora_config_correctly():
-    # Test that when a value is not defined, the default values are used
-    # Default values: r=8, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"]
+def test_get_hf_peft_config_returns_lora_config_correctly():
+    """Test that tuning_config fields are passed to LoraConfig correctly;
+    if not defined, the default values are used
+    """
     tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)
 
     config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
@@ -53,79 +55,93 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
         "v_proj",
     }  # default value from peft_config.LoraConfig
 
-    # Test that when target_modules is ["all-linear"], we convert it to str type "all-linear"
+
+def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear():
+    """Test that when target_modules is ["all-linear"], we convert it to str type "all-linear" """
     tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])
 
     config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
     assert isinstance(config, LoraConfig)
-    assert config.r == 234
     assert config.target_modules == "all-linear"
-    assert config.lora_dropout == 0.05  # default value from peft_config.LoraConfig
 
 
-def test_get_hf_peft_config_returns_PT_config_correctly():
-    # Test that the prompt tuning config is set properly for each field
-    # when a value is not defined, the default values are used
-    # Default values:
-    # prompt_tuning_init="TEXT",
-    # prompt_tuning_init_text="Classify if the tweet is a complaint or not:"
+def test_get_hf_peft_config_returns_pt_config_correctly():
+    """Test that the prompt tuning config is set properly for each field.
+    When a value is not defined, the default values are used
+    """
     tuning_config = peft_config.PromptTuningConfig(num_virtual_tokens=12)
 
     config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "foo/bar/path")
     assert isinstance(config, PromptTuningConfig)
     assert config.task_type == "CAUSAL_LM"
-    assert config.prompt_tuning_init == "TEXT"
+    assert config.prompt_tuning_init == "TEXT"  # default value
    assert config.num_virtual_tokens == 12
     assert (
         config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
-    )
+    )  # default value
     assert config.tokenizer_name_or_path == "foo/bar/path"
 
-    # Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
+
+def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path():
+    """Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT"""
+
+    # When prompt_tuning_init is not TEXT, we can pass in None for tokenizer path
     tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="RANDOM")
     config = config_utils.get_hf_peft_config(None, tuning_config, None)
     assert isinstance(config, PromptTuningConfig)
     assert config.tokenizer_name_or_path is None
 
+    # When prompt_tuning_init is TEXT, exception is raised if tokenizer path is None
     tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="TEXT")
     with pytest.raises(ValueError) as err:
         config_utils.get_hf_peft_config(None, tuning_config, None)
     assert "tokenizer_name_or_path can't be None" in str(err.value)
 
 
-def test_create_tuning_config():
-    # Test that LoraConfig is created for peft_method Lora
-    # and fields are set properly
+def test_create_tuning_config_for_peft_method_lora():
+    """Test that LoraConfig is created for peft_method Lora
+    and fields are set properly.
+    If unknown fields are passed, they are ignored
+    """
     tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)
     assert isinstance(tune_config, peft_config.LoraConfig)
     assert tune_config.r == 234
     assert tune_config.lora_alpha == 32
     assert tune_config.lora_dropout == 0.05
+    assert not hasattr(tune_config, "foo")
+
 
-    # Test that PromptTuningConfig is created for peft_method pt
-    # and fields are set properly
+def test_create_tuning_config_for_peft_method_pt():
+    """Test that PromptTuningConfig is created for peft_method pt
+    and fields are set properly
+    """
     tune_config = config_utils.create_tuning_config(
         "pt", foo="x", prompt_tuning_init="RANDOM"
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
     assert tune_config.prompt_tuning_init == "RANDOM"
 
-    # Test that None is created for peft_method "None" or None
-    # and fields are set properly
-    tune_config = config_utils.create_tuning_config("None", foo="x")
+
+def test_create_tuning_config_for_peft_method_none():
+    """Test that None is returned for peft_method "None" or None"""
+    tune_config = config_utils.create_tuning_config("None")
     assert tune_config is None
 
-    tune_config = config_utils.create_tuning_config(None, foo="x")
+    tune_config = config_utils.create_tuning_config(None)
     assert tune_config is None
 
-    # Test that this function does not recognize any other peft_method
+
+def test_create_tuning_config_does_not_recognize_any_other_peft_method():
+    """Test that a tuning config is created only for peft_method "None" or None,
+    "lora" or "pt"; any other value raises an AssertionError
+    """
     with pytest.raises(AssertionError) as err:
-        tune_config = config_utils.create_tuning_config("hello", foo="x")
+        config_utils.create_tuning_config("hello", foo="x")
     assert str(err.value) == "peft config hello not defined in peft.py"
 
 
 def test_update_config_can_handle_dot_for_nested_field():
-    # Test update_config allows nested field
+    """Test that the function can read dotted field for kwargs fields"""
     config = peft_config.LoraConfig(r=5)
     assert config.lora_alpha == 32  # default value is 32
 
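For reference, a minimal usage sketch of the two helpers these tests exercise, inferred only from the calls above (treat the signatures as assumptions, not the canonical API; the peft_config import sits above the diffed region):

    from tuning.utils import config_utils

    # peft_method "lora" plus kwargs -> a peft_config.LoraConfig dataclass;
    # unknown kwargs (e.g. foo="x") are ignored, per the tests above
    tuning_config = config_utils.create_tuning_config("lora", r=16)

    # task type, tuning config, tokenizer path -> the Hugging Face peft config
    hf_config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")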

@@ -134,22 +150,32 @@ def test_update_config_can_handle_dot_for_nested_field():
     config_utils.update_config(config, **kwargs)
     assert config.lora_alpha == 98
 
-    # update an unknown field
+
+def test_update_config_does_nothing_for_unknown_field():
+    """Test that the function does not change other config
+    field values if a kwarg field is unknown
+    """
+    # foobar is an unknown field
+    config = peft_config.LoraConfig(r=5)
     kwargs = {"LoraConfig.foobar": 98}
     config_utils.update_config(config, **kwargs)  # nothing happens
+    assert config.r == 5  # did not change r value
+    assert not hasattr(config, "foobar")
 
 
 def test_update_config_can_handle_multiple_config_updates():
-    # update a tuple of configs
+    """Test that the function can handle a tuple of configs"""
     config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
     kwargs = {"r": 98}
     config_utils.update_config(config, **kwargs)
     assert config[0].r == 98
     assert config[1].r == 98
 
 
-def test_get_json_config_can_load_from_path_or_envvar():
-    # Load from path
+def test_get_json_config_can_load_from_path():
+    """Test that the function get_json_config can read
+    the json path from env var SFT_TRAINER_CONFIG_JSON_PATH
+    """
     if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
         del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
     os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH
os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH
@@ -158,7 +184,11 @@ def test_get_json_config_can_load_from_path_or_envvar():
     assert job_config is not None
     assert job_config["model_name_or_path"] == "bigscience/bloom-560m"
 
-    # Load from envvar
+
+def test_get_json_config_can_load_from_envvar():
+    """Test that the function get_json_config can read
+    the json config from env var SFT_TRAINER_CONFIG_JSON_ENV_VAR
+    """
     config_json = {"model_name_or_path": "foobar"}
     message_bytes = pickle.dumps(config_json)
     base64_bytes = base64.b64encode(message_bytes)
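The final hunk is truncated here, but its setup suggests the env-var path works by pickling the config dict, base64-encoding it, and publishing the result in SFT_TRAINER_CONFIG_JSON_ENV_VAR for get_json_config to decode. A round-trip sketch under that assumption (get_json_config is assumed to be a zero-argument reader; the truncated test body may differ):

    import base64
    import os
    import pickle

    from tuning.utils import config_utils

    config_json = {"model_name_or_path": "foobar"}
    encoded = base64.b64encode(pickle.dumps(config_json))
    os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = encoded.decode("ascii")

    job_config = config_utils.get_json_config()  # assumed zero-arg reader
    assert job_config["model_name_or_path"] == "foobar"

Either way, the split tests can be run on their own with pytest tests/utils/test_config_utils.py.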
