
Commit dc32eda

Add test for tokenizer in lora config (should be ignored)
Signed-off-by: Angel Luu <angel.luu@us.ibm.com>
1 parent: 3e831c6

File tree

1 file changed (+26, -3 lines)


tests/utils/test_config_utils.py

Lines changed: 26 additions & 3 deletions
@@ -62,6 +62,22 @@ def test_get_hf_peft_config_returns_lora_config_correctly():
     ) # default value from HF peft.LoraConfig
 
 
+def test_get_hf_peft_config_ignores_tokenizer_path_for_lora_config():
+    """Test that if tokenizer is given with a LoraConfig, it is ignored"""
+    tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)
+
+    config = config_utils.get_hf_peft_config(
+        task_type="CAUSAL_LM",
+        tuning_config=tuning_config,
+        tokenizer_name_or_path="foo/bar/path",
+    )
+    assert isinstance(config, LoraConfig)
+    assert config.task_type == "CAUSAL_LM"
+    assert config.r == 3
+    assert config.lora_alpha == 3
+    assert not hasattr(config, "tokenizer_name_or_path")
+
+
 def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear():
     """Test that when target_modules is ["all-linear"], we convert it to str type "all-linear" """
     tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])
@@ -95,18 +111,25 @@ def test_get_hf_peft_config_returns_pt_config_correctly():
 
 
 def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path():
-    """Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT"""
+    """Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
+    Reference:
+    https://github.com/huggingface/peft/blob/main/src/peft/tuners/prompt_tuning/config.py#L73
+    """
 
     # When prompt_tuning_init is not TEXT, we can pass in None for tokenizer path
     tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="RANDOM")
-    config = config_utils.get_hf_peft_config(None, tuning_config, None)
+    config = config_utils.get_hf_peft_config(
+        task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
+    )
     assert isinstance(config, PromptTuningConfig)
     assert config.tokenizer_name_or_path is None
 
     # When prompt_tuning_init is TEXT, exception is raised if tokenizer path is None
     tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="TEXT")
     with pytest.raises(ValueError) as err:
-        config_utils.get_hf_peft_config(None, tuning_config, None)
+        config_utils.get_hf_peft_config(
+            task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
+        )
     assert "tokenizer_name_or_path can't be None" in err.value
 
 
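For context, get_hf_peft_config is the helper under test: it maps the project's own tuning dataclasses onto Hugging Face peft config objects. Its real implementation lives in the repo's config_utils module and is not part of this diff; the following is only a minimal sketch of the behavior the tests above rely on, with stand-in dataclasses (TuningLoraConfig, TuningPromptTuningConfig) invented here for illustration.

# Hedged sketch only: not the repo's actual implementation. The dataclass
# names below are stand-ins for the project's tuning config classes.
from dataclasses import dataclass, asdict

from peft import LoraConfig, PromptTuningConfig


@dataclass
class TuningLoraConfig:  # stand-in for the project's LoRA dataclass
    r: int = 8
    lora_alpha: int = 32


@dataclass
class TuningPromptTuningConfig:  # stand-in for the prompt-tuning dataclass
    prompt_tuning_init: str = "TEXT"
    num_virtual_tokens: int = 8


def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):
    """Map a tuning dataclass onto the matching HF peft config object."""
    if isinstance(tuning_config, TuningLoraConfig):
        # HF's LoraConfig has no tokenizer field, so the path is
        # deliberately dropped here, hence the hasattr assertion above.
        return LoraConfig(task_type=task_type, **asdict(tuning_config))
    # HF's PromptTuningConfig.__post_init__ raises ValueError
    # ("tokenizer_name_or_path can't be None") when prompt_tuning_init
    # is TEXT and no tokenizer path is supplied.
    return PromptTuningConfig(
        task_type=task_type,
        tokenizer_name_or_path=tokenizer_name_or_path,
        **asdict(tuning_config),
    )

Under this shape, a tokenizer path passed alongside a LoRA config is dropped rather than stored, which is why the new test asserts the attribute is absent instead of comparing it to None, while HF's PromptTuningConfig validation (linked in the docstring above) still rejects a None path when prompt_tuning_init is TEXT.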
