@@ -26,9 +26,13 @@
 import tempfile
 import shutil
 from pathlib import Path
+import json

 # Third Party
 from accelerate.commands.launch import launch_command
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+from torch import bfloat16

 # Local
 from build.utils import (
@@ -44,10 +48,18 @@
     USER_ERROR_EXIT_CODE,
     INTERNAL_ERROR_EXIT_CODE,
 )
+from tuning.data import tokenizer_data_utils

 ERROR_LOG = "/dev/termination-log"


+def get_base_model_from_adapter_config(adapter_config):
+    """Given the path to an adapter_config.json file, return the base model name."""
+    with open(adapter_config, "r", encoding="utf-8") as config_file:
+        adapter_config = json.load(config_file)
+    return adapter_config.get("base_model_name_or_path")
+
+
 def main():
     LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
     logging.basicConfig(level=LOGLEVEL)
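The new get_base_model_from_adapter_config helper reads the "base_model_name_or_path" field that PEFT records in an adapter checkpoint's adapter_config.json. A minimal, self-contained sketch of the same lookup, using a made-up config and a placeholder model id:

# Illustration only: exercise the helper's logic against a made-up adapter_config.json.
# The directory, file contents, and model name are placeholders, not real values.
import json
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as checkpoint_dir:
    adapter_config_path = Path(checkpoint_dir) / "adapter_config.json"
    adapter_config_path.write_text(
        json.dumps(
            {
                "base_model_name_or_path": "org/base-model-name",  # placeholder id
                "peft_type": "LORA",
            }
        ),
        encoding="utf-8",
    )

    # same logic as get_base_model_from_adapter_config above
    with open(adapter_config_path, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)
    print(config.get("base_model_name_or_path"))  # -> org/base-model-name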
@@ -118,16 +130,89 @@ def main():
         sys.exit(INTERNAL_ERROR_EXIT_CODE)

     try:
-        # copy last checkpoint into mounted output dir
-        pt_checkpoint_dir = get_highest_checkpoint(tempdir)
-        logging.info(
-            "Copying last checkpoint %s into output dir %s",
-            pt_checkpoint_dir,
-            original_output_dir,
-        )
-        copy_checkpoint(
-            os.path.join(tempdir, pt_checkpoint_dir), original_output_dir
+        last_checkpoint_dir = get_highest_checkpoint(tempdir)
+        last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir)
+
+        use_flash_attn = job_config.get("use_flash_attn", True)
+        adapter_config_path = os.path.join(
+            last_checkpoint_path, "adapter_config.json"
         )
+        tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path)
+
+        if os.path.exists(adapter_config_path):
+            base_model_path = get_base_model_from_adapter_config(
+                adapter_config_path
+            )
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+
+            # The peft library (PeftModelForCausalLM) does not handle cases where
+            # the base model's layers have been modified; here the embedding layer
+            # was resized during tuning, so resize the base model's embedding layer
+            # with our own utility before loading the PEFT model on top of it.
+            tokenizer_data_utils.tokenizer_and_embedding_resize(
+                {}, tokenizer=tokenizer, model=base_model
+            )
+            model = PeftModel.from_pretrained(
+                base_model,
+                last_checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                last_checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+
+        model_arch = model.config.model_type
+        # Check whether this is a Granite model with llama architecture and tied
+        # weights, i.e. lm_head is a duplicate of the embedding weights.
+
+        # A fine-tuned model has params_dict.get("model.embed_tokens.weight"),
+        # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight"),
+        # a LoRA adapter has params_dict.get("base_model.model.model.embed_tokens.weight").
+        copy_checkpoint_bool = True
+        if model_arch == "llama" and hasattr(model, "lm_head"):
+            if (
+                # a LoRA-tuned model has an additional model layer
+                (
+                    hasattr(model.model, "model")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+                # a prompt-tuned or fully fine-tuned model
+                or (
+                    hasattr(model.model, "embed_tokens")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+            ):
+
+                copy_checkpoint_bool = False
+                logging.info("Removing lm_head from checkpoint")
+                del model.lm_head.weight
+
+                if hasattr(model.lm_head, "weight"):
+                    logging.warning("Failed to delete lm_head.weight from model")
+
+        logging.info("Saving checkpoint to %s", original_output_dir)
+        model.save_pretrained(original_output_dir)
+        # save the tokenizer with the model
+        tokenizer.save_pretrained(original_output_dir)
+
+        # copy the last checkpoint into the mounted output dir
+        if copy_checkpoint_bool:
+            logging.info(
+                "Copying last checkpoint %s into output dir %s",
+                last_checkpoint_dir,
+                original_output_dir,
+            )
+            copy_checkpoint(last_checkpoint_path, original_output_dir)
     except Exception as e:  # pylint: disable=broad-except
         logging.error(traceback.format_exc())
         write_termination_log(
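The embedding-resize step above exists because a tuned checkpoint's tokenizer may carry added special tokens, so the adapter weights expect an embedding matrix larger than the one the stock base model ships with. A hedged sketch of the same idea using the generic Hugging Face resize_token_embeddings API rather than this repo's tokenizer_data_utils utility (the paths and model id are placeholders):

# Sketch only: resize the base model's embeddings to match the checkpoint's
# tokenizer before attaching the adapter. Paths and model id are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

checkpoint_path = "/path/to/last/checkpoint"  # placeholder
base_model_path = "org/base-model-name"       # placeholder

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
base_model = AutoModelForCausalLM.from_pretrained(base_model_path)

# If tuning added tokens, PeftModel.from_pretrained would otherwise hit a shape
# mismatch between the adapter and the base model's embedding matrix.
if len(tokenizer) != base_model.get_input_embeddings().weight.shape[0]:
    base_model.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base_model, checkpoint_path)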
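The lm_head check above relies on the fact that tied weights share one underlying storage, so comparing untyped_storage().data_ptr() of lm_head.weight and embed_tokens.weight distinguishes tied from untied models. A toy, self-contained illustration of that check (module sizes are arbitrary; untyped_storage requires PyTorch 2.x):

# Toy illustration of the tied-weight check: once two modules share a weight
# tensor, their storages report the same data pointer.
from torch import nn

vocab_size, hidden_dim = 32, 8
embed_tokens = nn.Embedding(vocab_size, hidden_dim)
lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

# untied: separate allocations, different pointers
assert (
    lm_head.weight.untyped_storage().data_ptr()
    != embed_tokens.weight.untyped_storage().data_ptr()
)

# tie the weights the way models with tie_word_embeddings=True do
lm_head.weight = embed_tokens.weight

# tied: same storage, so the checkpoint does not need a separate lm_head tensor
assert (
    lm_head.weight.untyped_storage().data_ptr()
    == embed_tokens.weight.untyped_storage().data_ptr()
)

When the pointers match, the script drops the duplicated lm_head tensor and saves the model directly instead of copying the raw checkpoint directory.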