Skip to content

Commit 3e831c6

Browse files
committed
Update more field/value checks from HF defaults
Signed-off-by: Angel Luu <angel.luu@us.ibm.com>
1 parent 906ce02 commit 3e831c6

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

tests/utils/test_config_utils.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,17 @@ def test_get_hf_peft_config_returns_lora_config_correctly():
4949
assert config.task_type == "CAUSAL_LM"
5050
assert config.r == 3
5151
assert config.lora_alpha == 3
52-
assert config.lora_dropout == 0.05 # default value from peft_config.LoraConfig
52+
assert (
53+
config.lora_dropout == 0.05
54+
) # default value from local peft_config.LoraConfig
5355
assert config.target_modules == {
5456
"q_proj",
5557
"v_proj",
56-
} # default value from peft_config.LoraConfig
58+
} # default value from local peft_config.LoraConfig
59+
assert config.init_lora_weights is True # default value from HF peft.LoraConfig
60+
assert (
61+
config.megatron_core == "megatron.core"
62+
) # default value from HF peft.LoraConfig
5763

5864

5965
def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear():
@@ -74,12 +80,18 @@ def test_get_hf_peft_config_returns_pt_config_correctly():
7480
config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "foo/bar/path")
7581
assert isinstance(config, PromptTuningConfig)
7682
assert config.task_type == "CAUSAL_LM"
77-
assert config.prompt_tuning_init == "TEXT" # default value
83+
assert (
84+
config.prompt_tuning_init == "TEXT"
85+
) # default value from local peft_config.PromptTuningConfig
7886
assert config.num_virtual_tokens == 12
7987
assert (
8088
config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
81-
) # default value
89+
) # default value from local peft_config.PromptTuningConfig
8290
assert config.tokenizer_name_or_path == "foo/bar/path"
91+
assert config.num_layers is None # default value from HF peft.PromptTuningConfig
92+
assert (
93+
config.inference_mode is False
94+
) # default value from HF peft.PromptTuningConfig
8395

8496

8597
def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path():

0 commit comments

Comments
 (0)