Add config_utils tests #262
Changes from all commits: a8bd8dd, 7054070, 906ce02, 3e831c6, dc32eda
@@ -0,0 +1,235 @@

# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
import base64
import os
import pickle

# Third Party
from peft import LoraConfig, PromptTuningConfig
import pytest

# First Party
from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH

# Local
from tuning.config import peft_config
from tuning.utils import config_utils


def test_get_hf_peft_config_returns_None_for_tuning_config_None():
    """Test that when tuning_config is None, the function returns None."""
    expected_config = None
    assert expected_config == config_utils.get_hf_peft_config("", None, "")


def test_get_hf_peft_config_returns_lora_config_correctly():
    """Test that tuning_config fields are passed to LoraConfig correctly;
    if a field is not defined, the default value is used.
    """
    tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)

    config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
    assert isinstance(config, LoraConfig)
    assert config.task_type == "CAUSAL_LM"
    assert config.r == 3
    assert config.lora_alpha == 3
    assert (
        config.lora_dropout == 0.05
    )  # default value from local peft_config.LoraConfig
    assert config.target_modules == {
        "q_proj",
        "v_proj",
    }  # default value from local peft_config.LoraConfig
    assert config.init_lora_weights is True  # default value from HF peft.LoraConfig
    assert (
        config.megatron_core == "megatron.core"
    )  # default value from HF peft.LoraConfig


def test_get_hf_peft_config_ignores_tokenizer_path_for_lora_config():
    """Test that if a tokenizer is given with a LoraConfig, it is ignored."""
    tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)

    config = config_utils.get_hf_peft_config(
        task_type="CAUSAL_LM",
        tuning_config=tuning_config,
        tokenizer_name_or_path="foo/bar/path",
    )
    assert isinstance(config, LoraConfig)
    assert config.task_type == "CAUSAL_LM"
    assert config.r == 3
    assert config.lora_alpha == 3
    assert not hasattr(config, "tokenizer_name_or_path")


def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear():
    """Test that when target_modules is ["all-linear"], it is converted to the str "all-linear" """
    tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])

    config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
    assert isinstance(config, LoraConfig)
    assert config.target_modules == "all-linear"


def test_get_hf_peft_config_returns_pt_config_correctly():
    """Test that the prompt tuning config is set properly for each field;
    when a value is not defined, the default value is used.
    """
    tuning_config = peft_config.PromptTuningConfig(num_virtual_tokens=12)

    config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "foo/bar/path")
    assert isinstance(config, PromptTuningConfig)
    assert config.task_type == "CAUSAL_LM"
    assert (
        config.prompt_tuning_init == "TEXT"
    )  # default value from local peft_config.PromptTuningConfig
    assert config.num_virtual_tokens == 12
    assert (
        config.prompt_tuning_init_text == "Classify if the tweet is a complaint or not:"
    )  # default value from local peft_config.PromptTuningConfig
    assert config.tokenizer_name_or_path == "foo/bar/path"

    assert config.num_layers is None  # default value from HF peft.PromptTuningConfig
    assert (
        config.inference_mode is False
    )  # default value from HF peft.PromptTuningConfig


def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path():
    """Test that the tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT.
    Reference:
    https://github.com/huggingface/peft/blob/main/src/peft/tuners/prompt_tuning/config.py#L73
    """

    # When prompt_tuning_init is not TEXT, we can pass None for the tokenizer path
    tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="RANDOM")
    config = config_utils.get_hf_peft_config(
        task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
    )
    assert isinstance(config, PromptTuningConfig)
    assert config.tokenizer_name_or_path is None
Reviewer: Interesting! So that means others like [...]

aluu317: Updated the function call to be more readable. The function call is [...]. Other fields are not affected.
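For illustration, here is the readability change under discussion; the "before" form is presumed from the thread (only the keyword-argument form appears in the diff):

    # Presumed before: positional arguments, so it is unclear which None is which
    config_utils.get_hf_peft_config(None, tuning_config, None)

    # After (as in the diff): keyword arguments make each argument's role explicit
    config_utils.get_hf_peft_config(
        task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
    )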

    # When prompt_tuning_init is TEXT, an exception is raised if the tokenizer path is None
    tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="TEXT")
    with pytest.raises(ValueError) as err:
        config_utils.get_hf_peft_config(
            task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
        )
    assert "tokenizer_name_or_path can't be None" in str(err.value)


def test_create_tuning_config_for_peft_method_lora():
    """Test that LoraConfig is created for peft_method "lora"
    and fields are set properly.
    If unknown fields are passed, they are ignored.
    """
    tune_config = config_utils.create_tuning_config("lora", foo="x", r=234)

    assert isinstance(tune_config, peft_config.LoraConfig)
    assert tune_config.r == 234
    assert tune_config.lora_alpha == 32
    assert tune_config.lora_dropout == 0.05
    assert not hasattr(tune_config, "foo")


def test_create_tuning_config_for_peft_method_pt():
    """Test that PromptTuningConfig is created for peft_method "pt"
    and fields are set properly.
    """
    tune_config = config_utils.create_tuning_config(
        "pt", foo="x", prompt_tuning_init="RANDOM"
    )
    assert isinstance(tune_config, peft_config.PromptTuningConfig)
    assert tune_config.prompt_tuning_init == "RANDOM"


def test_create_tuning_config_for_peft_method_none():
    """Test that None is returned when peft_method is "None" or None."""
    tune_config = config_utils.create_tuning_config("None")
    assert tune_config is None

    tune_config = config_utils.create_tuning_config(None)
    assert tune_config is None


def test_create_tuning_config_does_not_recognize_any_other_peft_method():
    """Test that create_tuning_config only recognizes peft_method "None", None,
    "lora", or "pt"; any other value raises.
    """
    with pytest.raises(AssertionError) as err:
        config_utils.create_tuning_config("hello", foo="x")
    assert str(err.value) == "peft config hello not defined in peft.py"


def test_update_config_can_handle_dot_for_nested_field():
    """Test that the function can read a dotted field in kwargs keys."""
    config = peft_config.LoraConfig(r=5)
    assert config.lora_alpha == 32  # default value is 32

    # update lora_alpha to 98
    kwargs = {"LoraConfig.lora_alpha": 98}
    config_utils.update_config(config, **kwargs)
    assert config.lora_alpha == 98
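
# Note (inferred from the tests in this file, not from documented behavior): a
# "ClassName.field" kwargs key appears to target only configs of that class,
# while a bare "field" key (see the tuple test below) applies to every config
# that has the attribute. For example, both of these would set lora_alpha:
#   config_utils.update_config(config, **{"LoraConfig.lora_alpha": 98})
#   config_utils.update_config(config, lora_alpha=98)  # assumed from the bare "r" usage below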

def test_update_config_does_nothing_for_unknown_field():
    """Test that the function does not change other config
    field values if a kwargs field is unknown.
    """
    # foobar is an unknown field
    config = peft_config.LoraConfig(r=5)
    kwargs = {"LoraConfig.foobar": 98}
    config_utils.update_config(config, **kwargs)  # nothing happens

    assert config.r == 5  # r value is unchanged
    assert not hasattr(config, "foobar")


def test_update_config_can_handle_multiple_config_updates():
    """Test that the function can handle a tuple of configs."""
    config = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
    kwargs = {"r": 98}
    config_utils.update_config(config, **kwargs)
    assert config[0].r == 98
    assert config[1].r == 98


def test_get_json_config_can_load_from_path():
    """Test that get_json_config can read the json config from the path
    set in the env var SFT_TRAINER_CONFIG_JSON_PATH.
    """
    if "SFT_TRAINER_CONFIG_JSON_ENV_VAR" in os.environ:
        del os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"]
    os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH

    job_config = config_utils.get_json_config()
    assert job_config is not None
    assert job_config["model_name_or_path"] == "bigscience/bloom-560m"


def test_get_json_config_can_load_from_envvar():
    """Test that get_json_config can read the json config from the env var
    SFT_TRAINER_CONFIG_JSON_ENV_VAR.
    """
    config_json = {"model_name_or_path": "foobar"}
    message_bytes = pickle.dumps(config_json)
    base64_bytes = base64.b64encode(message_bytes)
    encoded_json = base64_bytes.decode("ascii")
    os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = encoded_json

    job_config = config_utils.get_json_config()
    assert job_config is not None
    assert job_config["model_name_or_path"] == "foobar"
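As a usage note, whatever launches the trainer would have to produce SFT_TRAINER_CONFIG_JSON_ENV_VAR the same way the test does; a minimal sketch mirroring the test above (the job_config contents are illustrative):

    # Standard
    import base64
    import os
    import pickle

    # Encode a job config into the env var that get_json_config reads
    job_config = {"model_name_or_path": "bigscience/bloom-560m"}
    os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = base64.b64encode(
        pickle.dumps(job_config)
    ).decode("ascii")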
Reviewer: Another test could be: what happens if a tokenizer is passed? Since tokenizer_name_or_path is by default set to model_name_or_path, that value will always get passed when this is run in sft_trainer.py. I believe the tokenizer is used in PromptTuning but would be ignored for LoRA?

Reviewer: This is the only comment that I think wasn't addressed. What do you think about adding a test where you run with LoRA and pass in a tokenizer (which we expect to be ignored)?
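A test along the suggested lines, essentially what test_get_hf_peft_config_ignores_tokenizer_path_for_lora_config in this diff already checks, might look like this (the function name and path value are illustrative):

    def test_lora_ignores_tokenizer_path():  # hypothetical test name
        tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)
        config = config_utils.get_hf_peft_config(
            task_type="CAUSAL_LM",
            tuning_config=tuning_config,
            tokenizer_name_or_path="some/tokenizer/path",  # illustrative value
        )
        # HF's LoraConfig has no tokenizer field, so the path should be dropped
        assert not hasattr(config, "tokenizer_name_or_path")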