 from tuning.utils import config_utils


-def test_get_hf_peft_config_returns_None_for_FT():
+def test_get_hf_peft_config_returns_None_for_tuning_config_None():
+    """Test that when tuning_config is None, the function returns None."""
     expected_config = None
     assert expected_config == config_utils.get_hf_peft_config("", None, "")


-def test_get_hf_peft_config_returns_Lora_config_correctly():
-    # Test that when a value is not defined, the default values are used
-    # Default values: r=8, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"]
+def test_get_hf_peft_config_returns_lora_config_correctly():
+    """Test that tuning_config fields are passed to LoraConfig correctly.
+    If a field is not defined, the default value is used.
+    """
     tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)

     config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
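For orientation between hunks: the assertions above and below pin down how config_utils.get_hf_peft_config behaves. The following is a minimal sketch inferred from those assertions only, not the repository's actual implementation; the asdict-based field mapping, the tuning.config.peft_config import path, and a peft version that normalizes a target_modules list into a set are all assumptions.

# Hypothetical sketch of get_hf_peft_config, reconstructed from the tests.
from dataclasses import asdict

from peft import LoraConfig, PromptTuningConfig

from tuning.config import peft_config


def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):
    # tuning_config of None means full fine tuning: no PEFT config is built
    if tuning_config is None:
        return None

    if isinstance(tuning_config, peft_config.LoraConfig):
        hf_config = LoraConfig(task_type=task_type, **asdict(tuning_config))
        # peft expects the special "all-linear" selector as a plain string,
        # so the one-element list form is collapsed
        if hf_config.target_modules == {"all-linear"}:
            hf_config.target_modules = "all-linear"
        return hf_config

    # prompt tuning with TEXT init needs a tokenizer to embed the init text
    if tuning_config.prompt_tuning_init == "TEXT" and tokenizer_name_or_path is None:
        raise ValueError("tokenizer_name_or_path can't be None")
    fields = asdict(tuning_config)
    fields.pop("tokenizer_name_or_path", None)  # the explicit argument wins
    return PromptTuningConfig(
        task_type=task_type,
        tokenizer_name_or_path=tokenizer_name_or_path,
        **fields,
    )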
@@ -53,79 +55,93 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
         "v_proj",
     }  # default value from peft_config.LoraConfig

-    # Test that when target_modules is ["all-linear"], we convert it to str type "all-linear"
+
+def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear():
+    """Test that when target_modules is ["all-linear"], it is converted to the string "all-linear"."""
     tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])

     config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
     assert isinstance(config, LoraConfig)
-    assert config.r == 234
     assert config.target_modules == "all-linear"
-    assert config.lora_dropout == 0.05  # default value from peft_config.LoraConfig


-def test_get_hf_peft_config_returns_PT_config_correctly():
-    # Test that the prompt tuning config is set properly for each field
-    # when a value is not defined, the default values are used
-    # Default values:
-    # prompt_tuning_init="TEXT",
-    # prompt_tuning_init_text="Classify if the tweet is a complaint or not:"
+def test_get_hf_peft_config_returns_pt_config_correctly():
+    """Test that the prompt tuning config is set properly for each field.
+    When a value is not defined, the default value is used.
+    """
     tuning_config = peft_config.PromptTuningConfig(num_virtual_tokens=12)

     config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "foo/bar/path")
     assert isinstance(config, PromptTuningConfig)
     assert config.task_type == "CAUSAL_LM"
-    assert config.prompt_tuning_init == "TEXT"
+    assert config.prompt_tuning_init == "TEXT"  # default value
     assert config.num_virtual_tokens == 12
     assert (
         config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
-    )
+    )  # default value
     assert config.tokenizer_name_or_path == "foo/bar/path"

-    # Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
+
+def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path():
+    """Test that the tokenizer path may be None only when prompt_tuning_init is not TEXT."""
+
+    # When prompt_tuning_init is not TEXT, None may be passed for the tokenizer path
     tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="RANDOM")
     config = config_utils.get_hf_peft_config(None, tuning_config, None)
     assert isinstance(config, PromptTuningConfig)
     assert config.tokenizer_name_or_path is None

+    # When prompt_tuning_init is TEXT, an exception is raised if the tokenizer path is None
     tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="TEXT")
     with pytest.raises(ValueError) as err:
         config_utils.get_hf_peft_config(None, tuning_config, None)
     assert "tokenizer_name_or_path can't be None" in err.value


-def test_create_tuning_config():
-    # Test that LoraConfig is created for peft_method Lora
-    # and fields are set properly
+def test_create_tuning_config_for_peft_method_lora():
+    """Test that LoraConfig is created for peft_method "lora"
+    and fields are set properly.
+    Unknown fields passed in are ignored.
+    """
     tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)
     assert isinstance(tune_config, peft_config.LoraConfig)
     assert tune_config.r == 234
     assert tune_config.lora_alpha == 32
     assert tune_config.lora_dropout == 0.05
+    assert not hasattr(tune_config, "foo")
+

-    # Test that PromptTuningConfig is created for peft_method pt
-    # and fields are set properly
+def test_create_tuning_config_for_peft_method_pt():
+    """Test that PromptTuningConfig is created for peft_method "pt"
+    and fields are set properly.
+    """
     tune_config = config_utils.create_tuning_config(
         "pt", foo="x", prompt_tuning_init="RANDOM"
     )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
     assert tune_config.prompt_tuning_init == "RANDOM"

-    # Test that None is created for peft_method "None" or None
-    # and fields are set properly
-    tune_config = config_utils.create_tuning_config("None", foo="x")
+
+def test_create_tuning_config_for_peft_method_none():
+    """Test that None is returned for peft_method "None" or None."""
+    tune_config = config_utils.create_tuning_config("None")
     assert tune_config is None

-    tune_config = config_utils.create_tuning_config(None, foo="x")
+    tune_config = config_utils.create_tuning_config(None)
     assert tune_config is None

-    # Test that this function does not recognize any other peft_method
+
+def test_create_tuning_config_does_not_recognize_any_other_peft_method():
+    """Test that only "lora", "pt", "None", and None are recognized
+    as peft methods; any other value raises.
+    """
     with pytest.raises(AssertionError) as err:
-        tune_config = config_utils.create_tuning_config("hello", foo="x")
+        config_utils.create_tuning_config("hello", foo="x")
     assert err.value == "peft config hello not defined in peft.py"


 def test_update_config_can_handle_dot_for_nested_field():
-    # Test update_config allows nested field
+    """Test that the function can read dotted fields in kwargs."""
     config = peft_config.LoraConfig(r=5)
     assert config.lora_alpha == 32  # default value is 32
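The create_tuning_config tests in this hunk imply a small dispatch plus keyword filtering: unknown kwargs such as foo="x" must be dropped before the dataclass is built. One plausible shape, sketched from the assertions alone (the dataclasses.fields filter is an assumption, not the repository's code):

import dataclasses

from tuning.config import peft_config


def create_tuning_config(peft_method, **kwargs):
    # only these four method values are accepted; anything else trips the assert
    assert peft_method in ("lora", "pt", "None", None), \
        f"peft config {peft_method} not defined in peft.py"
    if peft_method == "lora":
        config_class = peft_config.LoraConfig
    elif peft_method == "pt":
        config_class = peft_config.PromptTuningConfig
    else:
        return None  # "None" or None selects full fine tuning

    # keep only kwargs that match declared dataclass fields, so foo="x" is ignored
    known = {field.name for field in dataclasses.fields(config_class)}
    return config_class(**{k: v for k, v in kwargs.items() if k in known})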
@@ -134,22 +150,32 @@ def test_update_config_can_handle_dot_for_nested_field():
     config_utils.update_config(config, **kwargs)
     assert config.lora_alpha == 98

-    # update an unknown field
+
+def test_update_config_does_nothing_for_unknown_field():
+    """Test that the function leaves other config field
+    values unchanged when a kwargs field is unknown.
+    """
+    # foobar is an unknown field
+    config = peft_config.LoraConfig(r=5)
     kwargs = {"LoraConfig.foobar": 98}
     config_utils.update_config(config, **kwargs)  # nothing happens
+    assert config.r == 5  # the r value did not change
+    assert not hasattr(config, "foobar")


 def test_update_config_can_handle_multiple_config_updates():
-    # update a tuple of configs
+    """Test that the function can handle a tuple of configs."""
     config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
     kwargs = {"r": 98}
     config_utils.update_config(config, **kwargs)
     assert config[0].r == 98
     assert config[1].r == 98


-def test_get_json_config_can_load_from_path_or_envvar():
-    # Load from path
+def test_get_json_config_can_load_from_path():
+    """Test that get_json_config can read the JSON config from
+    the file path in env var SFT_TRAINER_CONFIG_JSON_PATH.
+    """
     if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
         del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
     os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH
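The update_config tests above pin down three behaviors: dotted "ClassName.field" keys address one config class, unknown fields are skipped without side effects, and a tuple of configs is updated element by element. A minimal sketch consistent with all three (the recursion and the hasattr guard are assumptions, not the repository's code):

def update_config(config, **kwargs):
    # a tuple (or list) of configs is updated element by element
    if isinstance(config, (tuple, list)):
        for sub_config in config:
            update_config(sub_config, **kwargs)
        return

    for key, value in kwargs.items():
        if "." in key:
            # dotted form: "LoraConfig.lora_alpha" only applies to that class
            class_name, _, field_name = key.partition(".")
            if type(config).__name__ != class_name:
                continue
            key = field_name
        if hasattr(config, key):  # unknown fields are silently ignored
            setattr(config, key, value)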
@@ -158,7 +184,11 @@ def test_get_json_config_can_load_from_path_or_envvar():
     assert job_config is not None
     assert job_config["model_name_or_path"] == "bigscience/bloom-560m"

-    # Load from envvar
+
+def test_get_json_config_can_load_from_envvar():
+    """Test that get_json_config can read the JSON config directly
+    from env var SFT_TRAINER_CONFIG_JSON_ENV_VAR (base64-encoded pickle).
+    """
     config_json = {"model_name_or_path": "foobar"}
     message_bytes = pickle.dumps(config_json)
     base64_bytes = base64.b64encode(message_bytes)
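Finally, the two get_json_config tests suggest the env-var channel is checked before the path: the path test deletes SFT_TRAINER_CONFIG_JSON_ENV_VAR up front, which only matters if that variable would otherwise win. A sketch under that precedence assumption (not the repository's code):

import base64
import json
import os
import pickle


def get_json_config():
    # SFT_TRAINER_CONFIG_JSON_ENV_VAR carries the whole config inline:
    # a pickled dict, base64-encoded
    json_env_var = os.getenv("SFT_TRAINER_CONFIG_JSON_ENV_VAR")
    if json_env_var:
        return pickle.loads(base64.b64decode(json_env_var))

    # SFT_TRAINER_CONFIG_JSON_PATH points at a JSON file on disk
    json_path = os.getenv("SFT_TRAINER_CONFIG_JSON_PATH")
    if json_path:
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)
    return None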