
Commit 537215f

Ssukriti and anhuong authored
fix: remove lm_head for granite with llama arch models (#258)
* initial code for deleting lm_head
* fix logic for copying checkpoint
* fix check that embed_tokens and lm_head weights are the same
* fix warning assertion
* fix lm_head check, remove test
* small fixes from code review
* fmt

Signed-off-by: Anh-Uong <anh.uong@ibm.com>
Co-authored-by: Anh-Uong <anh.uong@ibm.com>
1 parent 55ca612 commit 537215f
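
The check this commit introduces relies on a property of Granite models that use the llama architecture with tied weights: lm_head.weight and embed_tokens.weight point at the same underlying storage, so serializing both duplicates data. A minimal standalone sketch of that condition (not part of the commit; the toy Embedding/Linear modules stand in for the real model layers):

# Minimal sketch: with weight tying, lm_head and the embedding share storage, so
# comparing untyped_storage().data_ptr() (as the diff below does) detects the duplicate.
from torch import nn

embed_tokens = nn.Embedding(32, 8)       # stand-in for model.embed_tokens
lm_head = nn.Linear(8, 32, bias=False)   # stand-in for model.lm_head
lm_head.weight = embed_tokens.weight     # weight tying, as in tied granite/llama checkpoints

tied = (
    lm_head.weight.untyped_storage().data_ptr()
    == embed_tokens.weight.untyped_storage().data_ptr()
)
print(tied)  # True -> lm_head.weight is redundant and can be deleted before saving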

File tree: 1 file changed, +94 -9 lines changed

build/accelerate_launch.py: +94 -9
@@ -26,9 +26,13 @@
 import tempfile
 import shutil
 from pathlib import Path
+import json
 
 # Third Party
 from accelerate.commands.launch import launch_command
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+from torch import bfloat16
 
 # Local
 from build.utils import (
@@ -44,10 +48,18 @@
     USER_ERROR_EXIT_CODE,
     INTERNAL_ERROR_EXIT_CODE,
 )
+from tuning.data import tokenizer_data_utils
 
 ERROR_LOG = "/dev/termination-log"
 
 
+def get_base_model_from_adapter_config(adapter_config):
+    """Given path to adapter_config.json file, returns the base model name"""
+    with open(adapter_config, "r", encoding="utf-8") as config_file:
+        adapter_config = json.load(config_file)
+        return adapter_config.get("base_model_name_or_path")
+
+
 def main():
     LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
     logging.basicConfig(level=LOGLEVEL)
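
For context, the adapter_config.json that PEFT writes alongside an adapter checkpoint carries a base_model_name_or_path entry, which is all the helper above extracts. A hypothetical standalone check (the config contents and temporary path are illustrative, and get_base_model_from_adapter_config is assumed importable from build.accelerate_launch):

# Hypothetical usage: write a minimal adapter_config.json, then read back the base model name.
import json
import os
import tempfile

from build.accelerate_launch import get_base_model_from_adapter_config

cfg = {"base_model_name_or_path": "org/example-base-model", "peft_type": "LORA"}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "adapter_config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cfg, f)
    print(get_base_model_from_adapter_config(path))  # -> "org/example-base-model"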
@@ -118,16 +130,89 @@ def main():
         sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
     try:
-        # copy last checkpoint into mounted output dir
-        pt_checkpoint_dir = get_highest_checkpoint(tempdir)
-        logging.info(
-            "Copying last checkpoint %s into output dir %s",
-            pt_checkpoint_dir,
-            original_output_dir,
-        )
-        copy_checkpoint(
-            os.path.join(tempdir, pt_checkpoint_dir), original_output_dir
+        last_checkpoint_dir = get_highest_checkpoint(tempdir)
+        last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir)
+
+        use_flash_attn = job_config.get("use_flash_attn", True)
+        adapter_config_path = os.path.join(
+            last_checkpoint_path, "adapter_config.json"
         )
+        tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path)
+
+        if os.path.exists(adapter_config_path):
+            base_model_path = get_base_model_from_adapter_config(
+                adapter_config_path
+            )
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+
+            # since the peft library (PeftModelForCausalLM) does not handle cases
+            # where the model's layers are modified (in our case the embedding layer
+            # is modified), we resize the backbone model's embedding layer with our own
+            # utility before passing it along to load the PEFT model.
+            tokenizer_data_utils.tokenizer_and_embedding_resize(
+                {}, tokenizer=tokenizer, model=base_model
+            )
+            model = PeftModel.from_pretrained(
+                base_model,
+                last_checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                last_checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+
+        model_arch = model.config.model_type
+        # check that it is a granite model with llama architecture with tied weights,
+        # i.e. lm_head is a duplicate of the embeddings
+
+        # a fine tuned model will have params_dict.get("model.embed_tokens.weight")
+        # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
+        # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
+        copy_checkpoint_bool = True
+        if model_arch == "llama" and hasattr(model, "lm_head"):
+            if (
+                # lora tuned model has an additional model layer
+                (
+                    hasattr(model.model, "model")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+                # prompt tuned model or fine tuned model
+                or (
+                    hasattr(model.model, "embed_tokens")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+            ):
+
+                copy_checkpoint_bool = False
+                logging.info("Removing lm_head from checkpoint")
+                del model.lm_head.weight
+
+                if hasattr(model, "lm_head.weight"):
+                    logging.warning("Failed to delete lm_head.weight from model")
+
+                logging.info("Saving checkpoint to %s", original_output_dir)
+                model.save_pretrained(original_output_dir)
+                # save tokenizer with model
+                tokenizer.save_pretrained(original_output_dir)
+
+        # copy last checkpoint into mounted output dir
+        if copy_checkpoint_bool:
+            logging.info(
+                "Copying last checkpoint %s into output dir %s",
+                last_checkpoint_dir,
+                original_output_dir,
+            )
+            copy_checkpoint(last_checkpoint_path, original_output_dir)
     except Exception as e:  # pylint: disable=broad-except
         logging.error(traceback.format_exc())
         write_termination_log(
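
After model.save_pretrained, the effect of deleting lm_head.weight can be confirmed by listing the tensor names in the written checkpoint. A minimal sketch, assuming a fully fine-tuned (non-adapter) checkpoint saved as a single model.safetensors file in the mounted output dir; sharded or adapter checkpoints use different file names:

# Minimal verification sketch: list saved tensor names and confirm lm_head.weight is absent.
import os

from safetensors import safe_open

output_dir = "/output"  # hypothetical mounted output dir
with safe_open(os.path.join(output_dir, "model.safetensors"), framework="pt") as f:
    saved_keys = list(f.keys())

assert not any(k.endswith("lm_head.weight") for k in saved_keys), "lm_head.weight still saved"
print("embed_tokens saved:", any(k.endswith("embed_tokens.weight") for k in saved_keys))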
