 import pickle

 # Third Party
-import pytest
 from peft import LoraConfig, PromptTuningConfig
+import pytest

+# First Party
+from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH

 # Local
-from tuning.utils import config_utils
 from tuning.config import peft_config
-from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH
+from tuning.utils import config_utils
+

 def test_get_hf_peft_config_returns_None_for_FT():
     expected_config = None
     assert expected_config == config_utils.get_hf_peft_config("", None, "")

+
 def test_get_hf_peft_config_returns_Lora_config_correctly():
     # Test that when a value is not defined, the default values are used
     # Default values: r=8, lora_alpha=32, lora_dropout=0.05, target_modules=["q_proj", "v_proj"]
@@ -44,8 +47,11 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
     assert config.task_type == "CAUSAL_LM"
     assert config.r == 3
     assert config.lora_alpha == 3
-    assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
-    assert config.target_modules == {'q_proj', 'v_proj'} # default value from peft_config.LoraConfig
+    assert config.lora_dropout == 0.05  # default value from peft_config.LoraConfig
+    assert config.target_modules == {
+        "q_proj",
+        "v_proj",
+    }  # default value from peft_config.LoraConfig

     # Test that when target_modules is ["all-linear"], we convert it to str type "all-linear"
     tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])
@@ -54,7 +60,8 @@ def test_get_hf_peft_config_returns_Lora_config_correctly():
     assert isinstance(config, LoraConfig)
     assert config.r == 234
     assert config.target_modules == "all-linear"
-    assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
+    assert config.lora_dropout == 0.05  # default value from peft_config.LoraConfig
+

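For reference, a minimal sketch of the Lora branch of get_hf_peft_config that these assertions pin down. This is hypothetical, not the actual implementation, and it assumes peft_config.LoraConfig is a plain dataclass whose fields mirror peft.LoraConfig:

from dataclasses import asdict

from peft import LoraConfig


def to_hf_lora_config(tuning_config):
    # Assumption: the local tuning dataclass carries only fields that
    # peft.LoraConfig also accepts (r, lora_alpha, lora_dropout, ...).
    params = asdict(tuning_config)
    # A ["all-linear"] list collapses to the string "all-linear",
    # which peft interprets as "target every linear layer".
    if params.get("target_modules") == ["all-linear"]:
        params["target_modules"] = "all-linear"
    return LoraConfig(task_type="CAUSAL_LM", **params)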
 def test_get_hf_peft_config_returns_PT_config_correctly():
     # Test that the prompt tuning config is set properly for each field
@@ -69,7 +76,9 @@ def test_get_hf_peft_config_returns_PT_config_correctly():
     assert config.task_type == "CAUSAL_LM"
     assert config.prompt_tuning_init == "TEXT"
     assert config.num_virtual_tokens == 12
-    assert config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
+    assert (
+        config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
+    )
     assert config.tokenizer_name_or_path == "foo/bar/path"

     # Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
@@ -87,65 +96,70 @@ def test_get_hf_peft_config_returns_PT_config_correctly():
 def test_create_tuning_config():
     # Test that LoraConfig is created for peft_method Lora
     # and fields are set properly
-    tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)
+    tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)
     assert isinstance(tune_config, peft_config.LoraConfig)
     assert tune_config.r == 234
     assert tune_config.lora_alpha == 32
     assert tune_config.lora_dropout == 0.05

     # Test that PromptTuningConfig is created for peft_method pt
     # and fields are set properly
-    tune_config = config_utils.create_tuning_config("pt", foo="x", prompt_tuning_init="RANDOM")
+    tune_config = config_utils.create_tuning_config(
+        "pt", foo="x", prompt_tuning_init="RANDOM"
+    )
     assert isinstance(tune_config, peft_config.PromptTuningConfig)
     assert tune_config.prompt_tuning_init == "RANDOM"

     # Test that None is created for peft_method "None" or None
     # and fields are set properly
-    tune_config = config_utils.create_tuning_config("None", foo="x")
+    tune_config = config_utils.create_tuning_config("None", foo="x")
     assert tune_config is None

-    tune_config = config_utils.create_tuning_config(None, foo="x")
+    tune_config = config_utils.create_tuning_config(None, foo="x")
     assert tune_config is None

     # Test that this function does not recognize any other peft_method
     with pytest.raises(AssertionError) as err:
-        tune_config = config_utils.create_tuning_config("hello", foo="x")
+        tune_config = config_utils.create_tuning_config("hello", foo="x")
     assert err.value == "peft config hello not defined in peft.py"

+
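As a reading aid, a hypothetical sketch of the dispatch behaviour this test pins down. The real create_tuning_config lives in tuning/utils/config_utils.py; the field filtering below is an assumption inferred from foo="x" being silently accepted:

from dataclasses import fields

from tuning.config import peft_config


def create_tuning_config_sketch(peft_method, **kwargs):
    assert peft_method in ("lora", "pt", "None", None), (
        f"peft config {peft_method} not defined in peft.py"
    )
    if peft_method == "lora":
        config_class = peft_config.LoraConfig
    elif peft_method == "pt":
        config_class = peft_config.PromptTuningConfig
    else:
        return None  # "None" or None selects full fine-tuning
    # Unknown kwargs such as foo="x" are dropped rather than raising.
    known = {f.name for f in fields(config_class)}
    return config_class(**{k: v for k, v in kwargs.items() if k in known})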
 def test_update_config_can_handle_dot_for_nested_field():
     # Test update_config allows nested field
-    config = peft_config.LoraConfig(r=5)
-    assert config.lora_alpha == 32 # default value is 32
+    config = peft_config.LoraConfig(r=5)
+    assert config.lora_alpha == 32  # default value is 32

     # update lora_alpha to 98
-    kwargs = {'LoraConfig.lora_alpha': 98}
+    kwargs = {"LoraConfig.lora_alpha": 98}
     config_utils.update_config(config, **kwargs)
     assert config.lora_alpha == 98

     # update an unknown field
-    kwargs = {'LoraConfig.foobar': 98}
-    config_utils.update_config(config, **kwargs) # nothing happens
+    kwargs = {"LoraConfig.foobar": 98}
+    config_utils.update_config(config, **kwargs)  # nothing happens
+

 def test_update_config_can_handle_multiple_config_updates():
     # update a tuple of configs
-    config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
-    kwargs = {'r': 98}
+    config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
+    kwargs = {"r": 98}
     config_utils.update_config(config, **kwargs)
     assert config[0].r == 98
     assert config[1].r == 98

+
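A hypothetical sketch of the update semantics the two tests above rely on: dotted keys are scoped by config class name, unknown fields are ignored, and a tuple fans out to every element. Names follow the assertions; the real implementation may differ:

def update_config_sketch(config, **kwargs):
    if isinstance(config, (tuple, list)):
        # A tuple of configs fans out: every element gets the same updates.
        for sub_config in config:
            update_config_sketch(sub_config, **kwargs)
        return
    for key, value in kwargs.items():
        if "." in key:
            # "LoraConfig.lora_alpha" only applies when the class name matches.
            scope, _, key = key.partition(".")
            if scope != type(config).__name__:
                continue
        if hasattr(config, key):  # unknown fields: nothing happens
            setattr(config, key, value)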
 def test_get_json_config_can_load_from_path_or_envvar():
     # Load from path
     if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
-        del os.environ['SFT_TRAINER_CONFIG_JSON_ENV_VAR']
+        del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
     os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH

     job_config = config_utils.get_json_config()
     assert job_config is not None
     assert job_config["model_name_or_path"] == "bigscience/bloom-560m"

     # Load from envvar
-    config_json = {'model_name_or_path': 'foobar'}
+    config_json = {"model_name_or_path": "foobar"}
     message_bytes = pickle.dumps(config_json)
     base64_bytes = base64.b64encode(message_bytes)
     encoded_json = base64_bytes.decode("ascii")
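The encoding chain above (pickle.dumps -> base64.b64encode -> ASCII) suggests that get_json_config applies the inverse when it reads SFT_TRAINER_CONFIG_JSON_ENV_VAR; that is an assumption based on this setup, not on the implementation:

# Hypothetical decode path: ASCII string -> base64 -> pickle -> dict.
decoded_config = pickle.loads(base64.b64decode(encoded_json))
assert decoded_config == {"model_name_or_path": "foobar"}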