@@ -49,11 +49,17 @@ def test_get_hf_peft_config_returns_lora_config_correctly():
49
49
assert config .task_type == "CAUSAL_LM"
50
50
assert config .r == 3
51
51
assert config .lora_alpha == 3
52
- assert config .lora_dropout == 0.05 # default value from peft_config.LoraConfig
52
+ assert (
53
+ config .lora_dropout == 0.05
54
+ ) # default value from local peft_config.LoraConfig
53
55
assert config .target_modules == {
54
56
"q_proj" ,
55
57
"v_proj" ,
56
- } # default value from peft_config.LoraConfig
58
+ } # default value from local peft_config.LoraConfig
59
+ assert config .init_lora_weights is True # default value from HF peft.LoraConfig
60
+ assert (
61
+ config .megatron_core == "megatron.core"
62
+ ) # default value from HF peft.LoraConfig
57
63
58
64
59
65
def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear ():
@@ -74,12 +80,18 @@ def test_get_hf_peft_config_returns_pt_config_correctly():
74
80
config = config_utils .get_hf_peft_config ("CAUSAL_LM" , tuning_config , "foo/bar/path" )
75
81
assert isinstance (config , PromptTuningConfig )
76
82
assert config .task_type == "CAUSAL_LM"
77
- assert config .prompt_tuning_init == "TEXT" # default value
83
+ assert (
84
+ config .prompt_tuning_init == "TEXT"
85
+ ) # default value from local peft_config.PromptTuningConfig
78
86
assert config .num_virtual_tokens == 12
79
87
assert (
80
88
config .prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
81
- ) # default value
89
+ ) # default value from local peft_config.PromptTuningConfig
82
90
assert config .tokenizer_name_or_path == "foo/bar/path"
91
+ assert config .num_layers is None # default value from HF peft.PromptTuningConfig
92
+ assert (
93
+ config .inference_mode is False
94
+ ) # default value from HF peft.PromptTuningConfig
83
95
84
96
85
97
def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path ():
0 commit comments