@@ -49,11 +49,17 @@ def test_get_hf_peft_config_returns_lora_config_correctly():
4949 assert config .task_type == "CAUSAL_LM"
5050 assert config .r == 3
5151 assert config .lora_alpha == 3
52- assert config .lora_dropout == 0.05 # default value from peft_config.LoraConfig
52+ assert (
53+ config .lora_dropout == 0.05
54+ ) # default value from local peft_config.LoraConfig
5355 assert config .target_modules == {
5456 "q_proj" ,
5557 "v_proj" ,
56- } # default value from peft_config.LoraConfig
58+ } # default value from local peft_config.LoraConfig
59+ assert config .init_lora_weights is True # default value from HF peft.LoraConfig
60+ assert (
61+ config .megatron_core == "megatron.core"
62+ ) # default value from HF peft.LoraConfig
5763
5864
5965def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear ():
@@ -74,12 +80,18 @@ def test_get_hf_peft_config_returns_pt_config_correctly():
7480 config = config_utils .get_hf_peft_config ("CAUSAL_LM" , tuning_config , "foo/bar/path" )
7581 assert isinstance (config , PromptTuningConfig )
7682 assert config .task_type == "CAUSAL_LM"
77- assert config .prompt_tuning_init == "TEXT" # default value
83+ assert (
84+ config .prompt_tuning_init == "TEXT"
85+ ) # default value from local peft_config.PromptTuningConfig
7886 assert config .num_virtual_tokens == 12
7987 assert (
8088 config .prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
81- ) # default value
89+ ) # default value from local peft_config.PromptTuningConfig
8290 assert config .tokenizer_name_or_path == "foo/bar/path"
91+ assert config .num_layers is None # default value from HF peft.PromptTuningConfig
92+ assert (
93+ config .inference_mode is False
94+ ) # default value from HF peft.PromptTuningConfig
8395
8496
8597def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path ():
0 commit comments