pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -35,7 +35,7 @@ dependencies = [
"tokenizers<=0.22",
"tqdm>=4.66.2,<5.0",
"trl>=0.19.1,<0.20.0",
"peft @ git+https://github.com/huggingface/peft.git@293aea5df6db240856a77f89955d1a89ce38b50d",
"peft @ git+https://github.com/huggingface/peft.git@d0fa97413f9b642e5bb8ea4d290440b4316393b2",
"datasets>=4.0.0,<5.0.0",
"simpleeval>=0.9.13,<2.0",
"pillow>=11.0.0,<12.0",
tuning/sft_trainer.py: 16 changes (16 additions, 0 deletions)
@@ -379,6 +379,22 @@ def train(

added_tokens_dict = setup_tokenizer(tokenizer, data_args, model_args, model)

# If additional tokens were added and we are doing LoRA,
# we need to make the embedding layer trainable
# and make sure that weight tying is preserved.
# See https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/1673
# for more details.
if added_tokens_dict and isinstance(peft_config, LoraConfig):
    if added_tokens_dict.get("num_new_tokens", 0) > 0:
        modules_to_save = getattr(peft_config, "modules_to_save", []) or []
        # Add both the embedding layer and the output layer so that models
        # whose input and output embeddings are not tied are also covered
        modules_to_save.extend(["embed_tokens", "lm_head"])
        setattr(peft_config, "modules_to_save", modules_to_save)
        # `ensure_weight_tying` is ignored by peft if the weights are not tied
        # https://github.com/huggingface/peft/blob/v0.18.0.rc0/src/peft/tuners/tuners_utils.py#L1230
        setattr(peft_config, "ensure_weight_tying", True)

# Configure the collator and validate args related to packing prior to formatting the dataset
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)
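For reference, a minimal sketch of the LoRA config state the new block produces when the tokenizer gained new tokens. This is illustrative only: the LoRA hyperparameters (`r`, `lora_alpha`, `target_modules`) are placeholders rather than the repo's defaults, and `ensure_weight_tying` is assumed to exist on the pinned peft revision.

```python
from peft import LoraConfig

# Placeholder LoRA settings; the relevant part is modules_to_save, which makes
# the resized input embeddings and the output head trainable alongside the adapters.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
)

# Mirrors the setattr call in the diff; the pinned peft revision is assumed to
# honor this flag and to ignore it when the base model's embeddings are untied.
setattr(peft_config, "ensure_weight_tying", True)
```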
tuning/utils/config_utils.py: 6 changes (6 additions, 0 deletions)
@@ -97,6 +97,12 @@ def get_hf_peft_config(task_type, tuning_config, tokenizer_name_or_path):

if hasattr(tuning_config, "alora_invocation_string"):
    delattr(tuning_config, "alora_invocation_string")

# Make sure that weight tying is not broken in case
# the embedding layer is added as trainable under LoRA
modules_to_save = getattr(tuning_config, "modules_to_save", None) or []
if "embed_tokens" in modules_to_save or "lm_head" in modules_to_save:
    setattr(tuning_config, "ensure_weight_tying", True)

return tuning_config

if isinstance(tuning_config, peft_config.PromptTuningConfig):
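For clarity, a tiny standalone sketch of the corrected guard above. The `SimpleNamespace` stand-in for the tuning config is purely illustrative and not part of the repository.

```python
from types import SimpleNamespace

# Stand-in for the LoRA tuning config handled in config_utils.py.
tuning_config = SimpleNamespace(modules_to_save=["embed_tokens", "lm_head"])

# Corrected membership check: only trip the flag when either layer is present.
modules_to_save = getattr(tuning_config, "modules_to_save", None) or []
if "embed_tokens" in modules_to_save or "lm_head" in modules_to_save:
    setattr(tuning_config, "ensure_weight_tying", True)

print(getattr(tuning_config, "ensure_weight_tying", False))  # prints: True
```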