import pickle

# Third Party
-import pytest
from peft import LoraConfig, PromptTuningConfig
+import pytest

+# First Party
+from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH

# Local
-from tuning.utils import config_utils
from tuning.config import peft_config
-from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH
+from tuning.utils import config_utils
+

def test_get_hf_peft_config_returns_None_for_FT():
    expected_config = None
    assert expected_config == config_utils.get_hf_peft_config("", None, "")

+
def test_get_hf_peft_config_returns_Lora_config_correctly():
    # Test that when a value is not defined, the default values are used
    # Default values: r=8, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"]
@@ -44,8 +47,11 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
    assert config.task_type == "CAUSAL_LM"
    assert config.r == 3
    assert config.lora_alpha == 3
-    assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
-    assert config.target_modules == {'q_proj', 'v_proj'} # default value from peft_config.LoraConfig
+    assert config.lora_dropout == 0.05  # default value from peft_config.LoraConfig
+    assert config.target_modules == {
+        "q_proj",
+        "v_proj",
+    }  # default value from peft_config.LoraConfig

    # Test that when target_modules is ["all-linear"], we convert it to str type "all-linear"
    tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])
@@ -54,7 +60,8 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
    assert isinstance(config, LoraConfig)
    assert config.r == 234
    assert config.target_modules == "all-linear"
-    assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
+    assert config.lora_dropout == 0.05  # default value from peft_config.LoraConfig
+

def test_get_hf_peft_config_returns_PT_config_correctly():
    # Test that the prompt tuning config is set properly for each field
@@ -69,7 +76,9 @@ def test_get_hf_peft_config_returns_PT_config_correctly():
    assert config.task_type == "CAUSAL_LM"
    assert config.prompt_tuning_init == "TEXT"
    assert config.num_virtual_tokens == 12
-    assert config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
+    assert (
+        config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
+    )
    assert config.tokenizer_name_or_path == "foo/bar/path"

    # Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
@@ -87,65 +96,70 @@ def test_get_hf_peft_config_returns_PT_config_correctly():
def test_create_tuning_config():
    # Test that LoraConfig is created for peft_method Lora
    # and fields are set properly
-    tune_config = config_utils.create_tuning_config("lora", foo = "x", r = 234)
+    tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)
    assert isinstance(tune_config, peft_config.LoraConfig)
    assert tune_config.r == 234
    assert tune_config.lora_alpha == 32
    assert tune_config.lora_dropout == 0.05

    # Test that PromptTuningConfig is created for peft_method pt
    # and fields are set properly
-    tune_config = config_utils.create_tuning_config("pt", foo = "x", prompt_tuning_init = "RANDOM")
+    tune_config = config_utils.create_tuning_config(
+        "pt", foo="x", prompt_tuning_init="RANDOM"
+    )
    assert isinstance(tune_config, peft_config.PromptTuningConfig)
    assert tune_config.prompt_tuning_init == "RANDOM"

    # Test that None is created for peft_method "None" or None
    # and fields are set properly
-    tune_config = config_utils.create_tuning_config("None", foo = "x")
+    tune_config = config_utils.create_tuning_config("None", foo="x")
    assert tune_config is None

-    tune_config = config_utils.create_tuning_config(None, foo = "x")
+    tune_config = config_utils.create_tuning_config(None, foo="x")
    assert tune_config is None

    # Test that this function does not recognize any other peft_method
    with pytest.raises(AssertionError) as err:
-        tune_config = config_utils.create_tuning_config("hello", foo = "x")
+        tune_config = config_utils.create_tuning_config("hello", foo="x")
    assert err.value == "peft config hello not defined in peft.py"

+
def test_update_config_can_handle_dot_for_nested_field():
    # Test update_config allows nested field
-    config = peft_config.LoraConfig(r = 5)
-    assert config.lora_alpha == 32 # default value is 32
+    config = peft_config.LoraConfig(r=5)
+    assert config.lora_alpha == 32  # default value is 32

    # update lora_alpha to 98
-    kwargs = {'LoraConfig.lora_alpha': 98}
+    kwargs = {"LoraConfig.lora_alpha": 98}
    config_utils.update_config(config, **kwargs)
    assert config.lora_alpha == 98

    # update an unknown field
-    kwargs = {'LoraConfig.foobar': 98}
-    config_utils.update_config(config, **kwargs) # nothing happens
+    kwargs = {"LoraConfig.foobar": 98}
+    config_utils.update_config(config, **kwargs)  # nothing happens
+

def test_update_config_can_handle_multiple_config_updates():
    # update a tuple of configs
-    config = (peft_config.LoraConfig(r = 5), peft_config.LoraConfig(r = 7))
-    kwargs = {'r': 98}
+    config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
+    kwargs = {"r": 98}
    config_utils.update_config(config, **kwargs)
    assert config[0].r == 98
    assert config[1].r == 98

+
def test_get_json_config_can_load_from_path_or_envvar():
    # Load from path
    if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
-        del os.environ['SFT_TRAINER_CONFIG_JSON_ENV_VAR']
+        del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
    os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH

    job_config = config_utils.get_json_config()
    assert job_config is not None
    assert job_config["model_name_or_path"] == "bigscience/bloom-560m"

    # Load from envvar
-    config_json = {'model_name_or_path': 'foobar'}
+    config_json = {"model_name_or_path": "foobar"}
    message_bytes = pickle.dumps(config_json)
    base64_bytes = base64.b64encode(message_bytes)
    encoded_json = base64_bytes.decode("ascii")
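
The last hunk is truncated while the test is still encoding its payload. For context only, the encode/decode round trip it sets up can be sketched standalone with the standard library; the idea that config_utils.get_json_config() falls back to a base64-encoded, pickled dict in SFT_TRAINER_CONFIG_JSON_ENV_VAR is an assumption inferred from the test setup, not something shown in this diff.

# Standalone sketch, not part of the diff: mirrors the encoding steps used above.
import base64
import os
import pickle

config_json = {"model_name_or_path": "foobar"}
# Encode exactly as the test does: pickle, then base64, then ASCII-decode to a str.
encoded_json = base64.b64encode(pickle.dumps(config_json)).decode("ascii")
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = encoded_json

# Reverse of the encoding above; whether get_json_config() does exactly this is an
# assumption, since the remainder of the hunk is cut off.
decoded = pickle.loads(base64.b64decode(os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]))
assert decoded["model_name_or_path"] == "foobar"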