Skip to content

Commit 5495027

Browse files
committed
Fix fmt
1 parent 67f97e2 commit 5495027

File tree

1 file changed

+35
-21
lines changed

1 file changed

+35
-21
lines changed

tests/utils/test_config_utils.py

+35-21
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,22 @@
2121
import pickle
2222

2323
# Third Party
24-
import pytest
2524
from peft import LoraConfig, PromptTuningConfig
25+
import pytest
2626

27+
# First Party
28+
from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH
2729

2830
# Local
29-
from tuning.utils import config_utils
3031
from tuning.config import peft_config
31-
from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH
32+
from tuning.utils import config_utils
33+
3234

3335
def test_get_hf_peft_config_returns_None_for_FT():
3436
expected_config = None
3537
assert expected_config == config_utils.get_hf_peft_config("", None, "")
3638

39+
3740
def test_get_hf_peft_config_returns_Lora_config_correctly():
3841
# Test that when a value is not defined, the default values are used
3942
# Default values: r=8, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"]
@@ -44,8 +47,11 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
4447
assert config.task_type == "CAUSAL_LM"
4548
assert config.r == 3
4649
assert config.lora_alpha == 3
47-
assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
48-
assert config.target_modules == {'q_proj', 'v_proj'} # default value from peft_config.LoraConfig
50+
assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
51+
assert config.target_modules == {
52+
"q_proj",
53+
"v_proj",
54+
} # default value from peft_config.LoraConfig
4955

5056
# Test that when target_modules is ["all-linear"], we convert it to str type "all-linear"
5157
tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])
@@ -54,7 +60,8 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
5460
assert isinstance(config, LoraConfig)
5561
assert config.r == 234
5662
assert config.target_modules == "all-linear"
57-
assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
63+
assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
64+
5865

5966
def test_get_hf_peft_config_returns_PT_config_correctly():
6067
# Test that the prompt tuning config is set properly for each field
@@ -69,7 +76,9 @@ def test_get_hf_peft_config_returns_PT_config_correctly():
6976
assert config.task_type == "CAUSAL_LM"
7077
assert config.prompt_tuning_init == "TEXT"
7178
assert config.num_virtual_tokens == 12
72-
assert config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
79+
assert (
80+
config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
81+
)
7382
assert config.tokenizer_name_or_path == "foo/bar/path"
7483

7584
# Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
@@ -87,65 +96,70 @@ def test_get_hf_peft_config_returns_PT_config_correctly():
8796
def test_create_tuning_config():
8897
# Test that LoraConfig is created for peft_method Lora
8998
# and fields are set properly
90-
tune_config = config_utils.create_tuning_config("lora", foo= "x", r= 234)
99+
tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)
91100
assert isinstance(tune_config, peft_config.LoraConfig)
92101
assert tune_config.r == 234
93102
assert tune_config.lora_alpha == 32
94103
assert tune_config.lora_dropout == 0.05
95104

96105
# Test that PromptTuningConfig is created for peft_method pt
97106
# and fields are set properly
98-
tune_config = config_utils.create_tuning_config("pt", foo= "x", prompt_tuning_init= "RANDOM")
107+
tune_config = config_utils.create_tuning_config(
108+
"pt", foo="x", prompt_tuning_init="RANDOM"
109+
)
99110
assert isinstance(tune_config, peft_config.PromptTuningConfig)
100111
assert tune_config.prompt_tuning_init == "RANDOM"
101112

102113
# Test that None is created for peft_method "None" or None
103114
# and fields are set properly
104-
tune_config = config_utils.create_tuning_config("None", foo= "x")
115+
tune_config = config_utils.create_tuning_config("None", foo="x")
105116
assert tune_config is None
106117

107-
tune_config = config_utils.create_tuning_config(None, foo= "x")
118+
tune_config = config_utils.create_tuning_config(None, foo="x")
108119
assert tune_config is None
109120

110121
# Test that this function does not recognize any other peft_method
111122
with pytest.raises(AssertionError) as err:
112-
tune_config = config_utils.create_tuning_config("hello", foo = "x")
123+
tune_config = config_utils.create_tuning_config("hello", foo="x")
113124
assert err.value == "peft config hello not defined in peft.py"
114125

126+
115127
def test_update_config_can_handle_dot_for_nested_field():
116128
# Test update_config allows nested field
117-
config = peft_config.LoraConfig(r = 5)
118-
assert config.lora_alpha == 32 # default value is 32
129+
config = peft_config.LoraConfig(r=5)
130+
assert config.lora_alpha == 32 # default value is 32
119131

120132
# update lora_alpha to 98
121-
kwargs = {'LoraConfig.lora_alpha': 98}
133+
kwargs = {"LoraConfig.lora_alpha": 98}
122134
config_utils.update_config(config, **kwargs)
123135
assert config.lora_alpha == 98
124136

125137
# update an unknown field
126-
kwargs = {'LoraConfig.foobar': 98}
127-
config_utils.update_config(config, **kwargs) # nothing happens
138+
kwargs = {"LoraConfig.foobar": 98}
139+
config_utils.update_config(config, **kwargs) # nothing happens
140+
128141

129142
def test_update_config_can_handle_multiple_config_updates():
130143
# update a tuple of configs
131-
config = (peft_config.LoraConfig(r = 5), peft_config.LoraConfig(r = 7))
132-
kwargs = {'r': 98}
144+
config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
145+
kwargs = {"r": 98}
133146
config_utils.update_config(config, **kwargs)
134147
assert config[0].r == 98
135148
assert config[1].r == 98
136149

150+
137151
def test_get_json_config_can_load_from_path_or_envvar():
138152
# Load from path
139153
if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
140-
del os.environ['SFT_TRAINER_CONFIG_JSON_ENV_VAR']
154+
del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
141155
os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH
142156

143157
job_config = config_utils.get_json_config()
144158
assert job_config is not None
145159
assert job_config["model_name_or_path"] == "bigscience/bloom-560m"
146160

147161
# Load from envvar
148-
config_json = {'model_name_or_path': 'foobar'}
162+
config_json = {"model_name_or_path": "foobar"}
149163
message_bytes = pickle.dumps(config_json)
150164
base64_bytes = base64.b64encode(message_bytes)
151165
encoded_json = base64_bytes.decode("ascii")

0 commit comments

Comments
 (0)