2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@ dependencies = [
"tokenizers<=0.22",
"tqdm>=4.66.2,<5.0",
"trl>=0.19.1,<0.20.0",
"peft @ git+https://github.com/huggingface/peft.git@293aea5df6db240856a77f89955d1a89ce38b50d",
"peft @ git+https://github.com/romitjain/peft.git@0cb44e8bdd329201d0c68f1cb572f5871b4bbc7d",
Collaborator

Is this patch in PEFT? If not, we can wait until the patch is merged in PEFT.

Also, FYI: we have to pin this to a release or prerelease version and cannot pin it to a git commit, as that would prevent us from making a pip release.

@romitjain (Contributor, Author), Nov 19, 2025

This patch has not yet been merged into PEFT. It is approved, but I think it will take until later this week to merge.

I don't think they will cut a release this soon, since they published one (0.18.0) not long ago. The previous pin also pointed to a git commit, though.

"datasets>=4.0.0,<5.0.0",
"simpleeval>=0.9.13,<2.0",
"pillow>=11.0.0,<12.0",
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
"num_hidden_layers": 8,
"pad_token_id": 0,
"rms_norm_eps": 1e-06,
"tie_word_embeddings": false,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.30.2",
"use_cache": true,
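For context, setting tie_word_embeddings to true makes the tiny test model share a single weight matrix between its input embeddings and LM head, which is what the new tied-weights tests below rely on. A quick way to confirm that a checkpoint really ties its embeddings (a minimal sketch, assuming a standard transformers causal LM; the checkpoint path is a placeholder):

import torch
from transformers import AutoModelForCausalLM

# "path/to/tiny-llama-checkpoint" is a placeholder for the cached test model.
model = AutoModelForCausalLM.from_pretrained("path/to/tiny-llama-checkpoint")

embed = model.get_input_embeddings()
lm_head = model.get_output_embeddings()

# With tied word embeddings both modules point at the same storage.
print(model.config.tie_word_embeddings)
print(embed.weight.data_ptr() == lm_head.weight.data_ptr())
print(torch.equal(embed.weight, lm_head.weight))
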
174 changes: 151 additions & 23 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
@@ -633,7 +633,7 @@ def test_run_causallm_lora_invalid_train_params(param_name, param_val, exc_msg):
setattr(invalid_params, param_name, param_val)

with pytest.raises(ValueError, match=exc_msg):
sft_trainer.train(MODEL_ARGS, DATA_ARGS, invalid_params, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, DATA_ARGS, invalid_params, copy.deepcopy(PEFT_LORA_ARGS))


@pytest.mark.parametrize(
@@ -649,7 +649,7 @@ def test_run_causallm_lora_with_validation(dataset_path):
data_args = copy.deepcopy(DATA_ARGS)
data_args.validation_data_path = dataset_path

sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS))
_validate_training(tempdir, check_eval=True)


@@ -670,7 +670,7 @@ def test_run_causallm_lora_with_validation_data_formatting(dataset_path):
"### Text: {{element['Tweet text']}} \n\n### Label: {{text_label}}"
)

sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS))
_validate_training(tempdir, check_eval=True)


@@ -829,6 +829,134 @@ def test_successful_lora_target_modules_default_from_main():
"v_proj",
}, "target_modules are not set to the default values."

os.environ.pop("SFT_TRAINER_CONFIG_JSON_ENV_VAR", None)


def test_run_causallm_lora_add_special_tokens():
"""Check if embed layer is added as modules_to_save when special tokens are added"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

base_lora_args = copy.deepcopy(PEFT_LORA_ARGS)
base_lora_args.target_modules = ["q_proj"]

# Sample special tokens to add to the tokenizer
data_args = copy.deepcopy(DATA_ARGS)
data_args.add_special_tokens = [
"<|test_token_1|>",
"<|test_token_2|>",
"<|test_token_3|>",
]

sft_trainer.train(MODEL_ARGS, data_args, train_args, base_lora_args)

# validate lora tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "LORA")
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_path)

assert adapter_config.get("modules_to_save") is not None
assert "embed_tokens" in adapter_config.get("modules_to_save")

# Check if all special tokens passed are in tokenizer
for tok in data_args.add_special_tokens:
assert tok in tokenizer.vocab


modules_to_save_val_map = [
(None, []),
(["embed_tokens"], ["embed_tokens"]),
(["lm_head"], ["embed_tokens"]),
(["embed_tokens", "lm_head"], ["embed_tokens"]),
]


@pytest.mark.parametrize(
"modules_to_save, expected",
modules_to_save_val_map,
)
def test_run_causallm_lora_tied_weights_in_modules_to_save(modules_to_save, expected):
"""Check if a model with tied weights in modules to save is correctly trained"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

base_lora_args = copy.deepcopy(PEFT_LORA_ARGS)
base_lora_args.target_modules = ["q_proj"]
base_lora_args.modules_to_save = modules_to_save

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, base_lora_args)

# validate lora tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "LORA")

for module in expected:
assert module in adapter_config.get("modules_to_save")

# Load the model and merge it
loaded_model = TunedCausalLM.load(checkpoint_path, MAYKEYE_TINY_LLAMA_CACHED)
merged_model = loaded_model.peft_model.merge_and_unload()

# In all cases, the embedding and the LM head should not have diverged
embed_layer = merged_model.get_input_embeddings()
lm_layer = merged_model.get_output_embeddings()

assert torch.allclose(embed_layer.weight, lm_layer.weight)
assert embed_layer.weight.data_ptr() == lm_layer.weight.data_ptr()


target_modules_tie_val_map = [
(["embed_tokens"], ["embed_tokens"]),
(["lm_head"], ["embed_tokens"]),
(["embed_tokens", "lm_head"], ["embed_tokens"]),
]


@pytest.mark.parametrize(
"target_modules, expected",
target_modules_tie_val_map,
)
def test_run_causallm_lora_tied_weights_in_target_modules(target_modules, expected):
"""Check if a model with tied weights in target_modules is correctly trained"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

base_lora_args = copy.deepcopy(PEFT_LORA_ARGS)
base_lora_args.target_modules = target_modules

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, base_lora_args)

# validate lora tuning configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)
adapter_config = _get_adapter_config(checkpoint_path)
_validate_adapter_config(adapter_config, "LORA")

for module in expected:
assert module in adapter_config.get("target_modules")

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path, MAYKEYE_TINY_LLAMA_CACHED)

# In all cases, the embedding and the LM head should not have diverged
embed_layer = loaded_model.peft_model.get_input_embeddings()
lm_layer = loaded_model.peft_model.get_output_embeddings()
d_embed = embed_layer.get_delta_weight("default")
d_lm = lm_layer.get_delta_weight("default")

assert embed_layer.weight.data_ptr() == lm_layer.weight.data_ptr()
assert torch.allclose(
d_embed, d_lm, atol=1e-6
), f"Max diff between deltas: {(d_embed - d_lm).abs().max()}"


############################# Finetuning Tests #############################
@pytest.mark.parametrize(
@@ -1816,7 +1944,7 @@ def test_tokenizer_has_no_eos_token():
# If we handled this badly, we would probably get something like a
# TypeError: can only concatenate str (not "NoneType") to str error
# when we go to apply the data formatter.
sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_LORA_ARGS)
sft_trainer.train(model_args, DATA_ARGS, train_args, copy.deepcopy(PEFT_LORA_ARGS))
_validate_training(tempdir)


@@ -1828,7 +1956,7 @@ def test_invalid_dataset_text_field():
data_args.dataset_text_field = "not found"

with pytest.raises(KeyError):
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS))


### Tests that giving dataset_text_field as well as formatter template gives error
@@ -1840,7 +1968,7 @@ def test_invalid_dataset_text_field_and_formatter_template():
)

with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS))


### Tests passing formatter with invalid keys gives error
@@ -1852,7 +1980,7 @@ def test_invalid_formatter_template():
)

with pytest.raises(KeyError):
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS))


### Tests for bad training data (i.e., data_path is an unhappy value or points to an unhappy thing)
@@ -1862,7 +1990,7 @@ def test_malformatted_data():
data_args.training_data_path = MALFORMATTED_DATA

with pytest.raises((DatasetGenerationError, ValueError)):
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS))


def test_empty_data():
@@ -1871,7 +1999,7 @@ def test_empty_data():
data_args.training_data_path = EMPTY_DATA

with pytest.raises((DatasetGenerationError, ValueError)):
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, TRAIN_ARGS, copy.deepcopy(PEFT_LORA_ARGS))


### Tests for bad tuning module configurations
@@ -1900,7 +2028,7 @@ def test_no_packing_needs_dataset_text_field_or_data_formatter_template():
data_args.data_formatter_template = None

with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS))


# TODO: Fix this case
@@ -1914,7 +2042,7 @@ def test_no_packing_needs_reponse_template():
data_args.response_template = None

with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS))


### Tests for model dtype edge cases
@@ -1931,7 +2059,7 @@ def test_bf16_still_tunes_if_unsupported():
model_args = copy.deepcopy(MODEL_ARGS)
model_args.torch_dtype = "bfloat16"

sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_LORA_ARGS)
sft_trainer.train(model_args, DATA_ARGS, train_args, copy.deepcopy(PEFT_LORA_ARGS))
_validate_training(tempdir)


@@ -1944,7 +2072,7 @@ def test_bad_torch_dtype():
model_args.torch_dtype = "not a type"

with pytest.raises(ValueError):
sft_trainer.train(model_args, DATA_ARGS, train_args, PEFT_LORA_ARGS)
sft_trainer.train(model_args, DATA_ARGS, train_args, copy.deepcopy(PEFT_LORA_ARGS))


def test_run_with_additional_callbacks():
@@ -1958,7 +2086,7 @@ def test_run_with_additional_callbacks():
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
copy.deepcopy(PEFT_LORA_ARGS),
additional_callbacks=[TrainerCallback()],
)

@@ -1977,7 +2105,7 @@ def test_run_with_bad_additional_callbacks():
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
copy.deepcopy(PEFT_LORA_ARGS),
additional_callbacks=["NotSupposedToBeHere"],
)

@@ -1998,7 +2126,7 @@ def test_run_with_bad_experimental_metadata():
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
copy.deepcopy(PEFT_LORA_ARGS),
additional_callbacks=[TrainerCallback()],
exp_metadata=metadata,
)
@@ -2017,7 +2145,7 @@ def test_run_with_good_experimental_metadata():
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
copy.deepcopy(PEFT_LORA_ARGS),
additional_callbacks=[TrainerCallback()],
exp_metadata=metadata,
)
@@ -2040,7 +2168,7 @@ def test_pretokenized_dataset(dataset_path):
data_args.dataset_text_field = None
data_args.response_template = None
data_args.training_data_path = dataset_path
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS))
_validate_training(tempdir)


@@ -2064,7 +2192,7 @@ def test_pretokenized_dataset_bad_args(dataset_text_field, response_template):
# We should raise an error since we should not have a dataset text
# field or a response template if we have pretokenized data
with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS))


def test_pretokenized_dataset_wrong_format():
@@ -2082,7 +2210,7 @@ def test_pretokenized_dataset_wrong_format():
# need to directly add validation prior to the dataset generation since datasets
# is essentially swallowing a KeyError here.
with pytest.raises(ValueError):
sft_trainer.train(MODEL_ARGS, data_args, train_args, PEFT_LORA_ARGS)
sft_trainer.train(MODEL_ARGS, data_args, train_args, copy.deepcopy(PEFT_LORA_ARGS))


###########################################################################
@@ -2115,7 +2243,7 @@ def test_run_with_bad_additional_data_handlers(additional_handlers):
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
copy.deepcopy(PEFT_LORA_ARGS),
additional_data_handlers=additional_handlers,
)

@@ -2130,7 +2258,7 @@ def test_run_with_additional_data_handlers_as_none():
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
copy.deepcopy(PEFT_LORA_ARGS),
additional_data_handlers=None,
)
_validate_training(tempdir)
@@ -2177,7 +2305,7 @@ def test_handler(element, **kwargs):
MODEL_ARGS,
DATA_ARGS,
train_args,
PEFT_LORA_ARGS,
copy.deepcopy(PEFT_LORA_ARGS),
additional_data_handlers={
TEST_HANDLER: DataHandler(
op=test_handler,
27 changes: 27 additions & 0 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
@@ -379,6 +379,33 @@ def train(

added_tokens_dict = setup_tokenizer(tokenizer, data_args, model_args, model)

# If additional tokens are added and we are doing LoRA,
# we need to make the embedding layer trainable
# and ensure that the embedding and LM head weights stay tied.
# See https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/1673
# for more details.
if added_tokens_dict and isinstance(peft_config, LoraConfig):
if added_tokens_dict.get("num_new_tokens", 0) > 0:
logger.info(
"Adding embed_tokens and lm_head as trainable modules due to vocab expansion"
)
modules_to_save = getattr(peft_config, "modules_to_save", []) or []
target_modules = getattr(peft_config, "target_modules", []) or []

# If the base model's weights are not tied,
# we need to add both the embedding layer and the output layer.
# If the embedding layer or the LM head is already targeted via `target_modules`,
# we skip adding it to `modules_to_save`, since it is already adapted there.
if not any(m in target_modules for m in ("embed_tokens", "lm_head")):
modules_to_save.extend(["embed_tokens", "lm_head"])
setattr(peft_config, "modules_to_save", modules_to_save)

# This is safe to do for both tied and non-tied models:
# `ensure_weight_tying` is ignored if the weights are not tied.
# https://github.com/huggingface/peft/blob/v0.18.0.rc0/src/peft/tuners/tuners_utils.py#L1230
setattr(peft_config, "ensure_weight_tying", True)

# Configure the collator and validate args related to packing prior to formatting the dataset
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)
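To make the effect of the block above concrete, here is a minimal sketch of how the LoRA config ends up looking when train() detects newly added special tokens. It assumes a PEFT build that honors ensure_weight_tying (such as the pinned commit); the rank, alpha, and token count are placeholder values.

from peft import LoraConfig

# User-supplied config that only adapts the attention query projection.
peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj"])

num_new_tokens = 3  # e.g. three added special tokens

if num_new_tokens > 0:
    modules_to_save = list(peft_config.modules_to_save or [])
    target_modules = list(peft_config.target_modules or [])

    # Keep full, trainable copies of the embedding and LM head unless one of
    # them is already adapted through target_modules.
    if not any(m in target_modules for m in ("embed_tokens", "lm_head")):
        modules_to_save.extend(["embed_tokens", "lm_head"])
        peft_config.modules_to_save = modules_to_save

    # Ask PEFT to keep the saved embedding/LM head copies tied; the flag is
    # ignored when the base model does not tie its weights.
    peft_config.ensure_weight_tying = True

The new tests above then verify that, after merge_and_unload(), the embedding and LM head still share one weight tensor.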