Skip to content

Add config_utils tests #262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 235 additions & 0 deletions tests/utils/test_config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
import base64
import os
import pickle

# Third Party
from peft import LoraConfig, PromptTuningConfig
import pytest

# First Party
from tests.build.test_utils import HAPPY_PATH_DUMMY_CONFIG_PATH

# Local
from tuning.config import peft_config
from tuning.utils import config_utils


def test_get_hf_peft_config_returns_None_for_tuning_config_None():
    """Test that when tuning_config is None, the function returns None"""
    # No tuning config at all means no HF peft config either.
    assert config_utils.get_hf_peft_config("", None, "") is None


def test_get_hf_peft_config_returns_lora_config_correctly():
    """Test that tuning_config fields are passed to LoraConfig correctly.

    Fields left undefined on the local config fall back to default values,
    either from the local peft_config.LoraConfig or from HF peft.LoraConfig.
    """
    tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)

    config = config_utils.get_hf_peft_config("CAUSAL_LM", tuning_config, "")
    assert isinstance(config, LoraConfig)
    assert config.task_type == "CAUSAL_LM"
    # Explicitly-set fields are carried over as-is.
    assert config.r == 3
    assert config.lora_alpha == 3
    assert (
        config.lora_dropout == 0.05
    )  # default value from local peft_config.LoraConfig
    assert config.target_modules == {
        "q_proj",
        "v_proj",
    }  # default value from local peft_config.LoraConfig
    assert config.init_lora_weights is True  # default value from HF peft.LoraConfig
    assert (
        config.megatron_core == "megatron.core"
    )  # default value from HF peft.LoraConfig


def test_get_hf_peft_config_ignores_tokenizer_path_for_lora_config():
    """Test that if tokenizer is given with a LoraConfig, it is ignored"""
    lora_tuning_config = peft_config.LoraConfig(r=3, lora_alpha=3)

    hf_config = config_utils.get_hf_peft_config(
        task_type="CAUSAL_LM",
        tuning_config=lora_tuning_config,
        tokenizer_name_or_path="foo/bar/path",
    )
    # LoRA has no use for a tokenizer, so the path must not leak into the config.
    assert isinstance(hf_config, LoraConfig)
    assert hf_config.task_type == "CAUSAL_LM"
    assert hf_config.r == 3
    assert hf_config.lora_alpha == 3
    assert not hasattr(hf_config, "tokenizer_name_or_path")


def test_get_hf_peft_config_returns_lora_config_with_correct_value_for_all_linear():
    """Test that when target_modules is ["all-linear"], we convert it to str type "all-linear" """
    lora_tuning_config = peft_config.LoraConfig(r=234, target_modules=["all-linear"])

    hf_config = config_utils.get_hf_peft_config("CAUSAL_LM", lora_tuning_config, "")
    # The single-element list form must collapse to the plain string sentinel.
    assert isinstance(hf_config, LoraConfig)
    assert hf_config.target_modules == "all-linear"


def test_get_hf_peft_config_returns_pt_config_correctly():
    """Test that the prompt tuning config is set properly for each field
    When a value is not defined, the default values are used
    """
    pt_tuning_config = peft_config.PromptTuningConfig(num_virtual_tokens=12)

    hf_config = config_utils.get_hf_peft_config(
        "CAUSAL_LM", pt_tuning_config, "foo/bar/path"
    )
    assert isinstance(hf_config, PromptTuningConfig)
    # Explicitly supplied values are carried over.
    assert hf_config.task_type == "CAUSAL_LM"
    assert hf_config.num_virtual_tokens == 12
    assert hf_config.tokenizer_name_or_path == "foo/bar/path"
    # Defaults from the local peft_config.PromptTuningConfig.
    assert hf_config.prompt_tuning_init == "TEXT"
    assert (
        hf_config.prompt_tuning_init_text
        == "Classify if the tweet is a complaint or not:"
    )
    # Defaults inherited from HF peft.PromptTuningConfig.
    assert hf_config.num_layers is None
    assert hf_config.inference_mode is False


def test_get_hf_peft_config_returns_pt_config_with_correct_tokenizer_path():
    """Test that tokenizer path is allowed to be None only when prompt_tuning_init is not TEXT
    Reference:
    https://github.com/huggingface/peft/blob/main/src/peft/tuners/prompt_tuning/config.py#L73
    """

    # When prompt_tuning_init is not TEXT, we can pass in None for tokenizer path
    tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="RANDOM")
    config = config_utils.get_hf_peft_config(
        task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
    )
    assert isinstance(config, PromptTuningConfig)
    assert config.tokenizer_name_or_path is None

    # When prompt_tuning_init is TEXT, exception is raised if tokenizer path is None
    tuning_config = peft_config.PromptTuningConfig(prompt_tuning_init="TEXT")
    with pytest.raises(ValueError) as err:
        config_utils.get_hf_peft_config(
            task_type=None, tuning_config=tuning_config, tokenizer_name_or_path=None
        )
    # Check the exception *message*: `in err.value` would attempt membership on
    # the ValueError instance itself and raise TypeError.
    assert "tokenizer_name_or_path can't be None" in str(err.value)


def test_create_tuning_config_for_peft_method_lora():
    """Test that LoraConfig is created for peft_method Lora
    and fields are set properly.
    If unknown fields are passed, they are ignored
    """
    created = config_utils.create_tuning_config("lora", foo="x", r=234)
    assert isinstance(created, peft_config.LoraConfig)
    # The explicit kwarg wins; the unknown kwarg "foo" is silently dropped.
    assert created.r == 234
    assert created.lora_alpha == 32
    assert created.lora_dropout == 0.05
    assert not hasattr(created, "foo")


def test_create_tuning_config_for_peft_method_pt():
    """Test that PromptTuningConfig is created for peft_method pt
    and fields are set properly
    """
    created = config_utils.create_tuning_config(
        "pt", foo="x", prompt_tuning_init="RANDOM"
    )
    # Known kwargs are applied; unknown ones (foo) are ignored.
    assert isinstance(created, peft_config.PromptTuningConfig)
    assert created.prompt_tuning_init == "RANDOM"


def test_create_tuning_config_for_peft_method_none():
    """Test that PromptTuningConfig is created for peft_method "None" or None"""
    # Both the literal string "None" and an actual None yield no tuning config.
    for none_like in ("None", None):
        assert config_utils.create_tuning_config(none_like) is None


def test_create_tuning_config_does_not_recognize_any_other_peft_method():
    """Test that PromptTuningConfig is created for peft_method "None" or None,
    "lora" or "pt", and no other
    """
    with pytest.raises(AssertionError) as err:
        config_utils.create_tuning_config("hello", foo="x")
    # Compare against the exception message: the previous
    # `err.value == "..."` compared an AssertionError instance to a str,
    # which is always False, so the assertion could never pass.
    assert "peft config hello not defined in peft.py" in str(err.value)


def test_update_config_can_handle_dot_for_nested_field():
    """Test that the function can read dotted field for kwargs fields"""
    lora = peft_config.LoraConfig(r=5)
    assert lora.lora_alpha == 32  # default value is 32

    # A dotted key "<ClassName>.<field>" targets that field on the config.
    config_utils.update_config(lora, **{"LoraConfig.lora_alpha": 98})
    assert lora.lora_alpha == 98


def test_update_config_does_nothing_for_unknown_field():
    """Test that the function does not change other config
    field values if a kwarg field is unknown
    """
    lora = peft_config.LoraConfig(r=5)
    # "foobar" is not a LoraConfig field, so the update is a no-op.
    config_utils.update_config(lora, **{"LoraConfig.foobar": 98})
    assert lora.r == 5  # did not change r value
    assert not hasattr(lora, "foobar")


def test_update_config_can_handle_multiple_config_updates():
    """Test that the function can handle a tuple of configs"""
    configs = (peft_config.LoraConfig(r=5), peft_config.LoraConfig(r=7))
    config_utils.update_config(configs, r=98)
    # The same kwarg is applied to every config in the tuple.
    assert [c.r for c in configs] == [98, 98]


def test_get_json_config_can_load_from_path():
    """Test that the function get_json_config can read
    the json path from env var SFT_TRAINER_CONFIG_JSON_PATH
    """
    # Ensure the env-var variant cannot shadow the file-path variant.
    os.environ.pop("SFT_TRAINER_CONFIG_JSON_ENV_VAR", None)
    os.environ["SFT_TRAINER_CONFIG_JSON_PATH"] = HAPPY_PATH_DUMMY_CONFIG_PATH

    loaded = config_utils.get_json_config()
    assert loaded is not None
    assert loaded["model_name_or_path"] == "bigscience/bloom-560m"


def test_get_json_config_can_load_from_envvar():
    """Test that the function get_json_config can read
    the json path from env var SFT_TRAINER_CONFIG_JSON_ENV_VAR
    """
    # The env var carries a base64-encoded pickle of the config dict.
    raw_bytes = pickle.dumps({"model_name_or_path": "foobar"})
    os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = base64.b64encode(raw_bytes).decode(
        "ascii"
    )

    loaded = config_utils.get_json_config()
    assert loaded is not None
    assert loaded["model_name_or_path"] == "foobar"
Loading