|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import argparse |
| 8 | +import logging |
| 9 | +import copy |
| 10 | +import json |
| 11 | +import torch |
| 12 | +from lm_eval.evaluator import simple_evaluate |
| 13 | + |
| 14 | +from typing import List, Optional, Tuple |
| 15 | + |
| 16 | +import torch |
| 17 | +import torch.nn as nn |
| 18 | + |
| 19 | +from executorch.examples.models.llama.eval_llama_lib import ( |
| 20 | + build_args_parser, |
| 21 | + GraphModuleEvalWrapper |
| 22 | +) |
| 23 | + |
| 24 | +from pytorch_tokenizers import get_tokenizer |
| 25 | + |
| 26 | +from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( |
| 27 | + LlamaModel, |
| 28 | + ModelArgs, |
| 29 | +) |
| 30 | + |
| 31 | + |
| 32 | +class WrappedLlamaModel(nn.Module): |
| 33 | + def __init__(self, model, use_kv_cache=False, max_seq_len=512, device='cuda'): |
| 34 | + super(WrappedLlamaModel, self).__init__() |
| 35 | + self.model = model |
| 36 | + self.max_seq_len = max_seq_len |
| 37 | + self.use_kv_cache = use_kv_cache |
| 38 | + self.device = device |
| 39 | + |
| 40 | + def forward(self, |
| 41 | + tokens: torch.Tensor, |
| 42 | + input_pos: Optional[torch.Tensor] = None, |
| 43 | + *args, |
| 44 | + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: |
| 45 | + # Pad input if necessary, since LlamaModel requires static shape |
| 46 | + if tokens.shape[1] != self.max_seq_len: |
| 47 | + tokens = torch.nn.functional.pad(tokens, (self.max_seq_len - tokens.shape[1],0)) |
| 48 | + atten_mask = self.model.get_example_inputs(self.use_kv_cache)[1].to(device=self.device).to(dtype=torch.bfloat16) |
| 49 | + return self.model.forward(tokens, atten_mask, input_pos, *args) |
| 50 | + |
| 51 | + |
| 52 | + |
| 53 | +def gen_eval_wrapper(model_name, args): |
| 54 | + tokenizer = get_tokenizer(args.tokenizer_path) |
| 55 | + with open(args.params) as f: |
| 56 | + kv_config = ModelArgs(**json.load(f)) |
| 57 | + # TODO: support batch inputs if necessary |
| 58 | + kv_config.max_batch_size = 1 |
| 59 | + kv_config.max_seq_len = args.max_seq_length |
| 60 | + kv_config.use_kv_cache = True |
| 61 | + |
| 62 | + prefill_config = copy.copy(kv_config) |
| 63 | + prefill_config.max_seq_len = args.max_seq_length |
| 64 | + prefill_config.use_kv_cache = ( |
| 65 | + False if args.max_seq_length == args.prefill_ar_len else True |
| 66 | + ) |
| 67 | + config = prefill_config |
| 68 | + use_i64_token = args.embedding_quantize is not None |
| 69 | + model = LlamaModel(config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=False, use_i64_token=use_i64_token) |
| 70 | + state_dict = torch.load( |
| 71 | + args.checkpoint, weights_only=True, map_location=args.device, mmap=True |
| 72 | + ) |
| 73 | + |
| 74 | + # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. |
| 75 | + def permute(w, heads): |
| 76 | + dim_0 = w.size(0) |
| 77 | + dim_1 = w.size(1) |
| 78 | + return ( |
| 79 | + w.view(heads, dim_0 // heads // 2, 2, dim_1) |
| 80 | + .transpose(1, 2) |
| 81 | + .reshape(dim_0, dim_1) |
| 82 | + ) |
| 83 | + |
| 84 | + n_heads = model.n_heads |
| 85 | + n_kv_heads = model.n_kv_heads |
| 86 | + n_layers = model.n_layers |
| 87 | + |
| 88 | + for layer_i in range(n_layers): |
| 89 | + state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( |
| 90 | + state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads |
| 91 | + ) |
| 92 | + state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( |
| 93 | + state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads |
| 94 | + ) |
| 95 | + |
| 96 | + model.load_state_dict( |
| 97 | + state_dict, |
| 98 | + strict=True, |
| 99 | + assign=True, |
| 100 | + ) |
| 101 | + |
| 102 | + if "model" in state_dict: |
| 103 | + state_dict = state_dict["model"] |
| 104 | + |
| 105 | + for layer in model.layers: |
| 106 | + if getattr(layer.attention, "prepare_sha", None): |
| 107 | + layer.attention.prepare_sha() |
| 108 | + if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): |
| 109 | + layer.feed_forward.prepare_feedfoward_conv() |
| 110 | + |
| 111 | + model.to(dtype=torch.bfloat16) |
| 112 | + model.to(args.device) |
| 113 | + |
| 114 | + wrapped_model = WrappedLlamaModel(model, args.use_kv_cache, args.max_seq_length, args.device) |
| 115 | + |
| 116 | + return GraphModuleEvalWrapper( |
| 117 | + model=wrapped_model, |
| 118 | + tokenizer=tokenizer, |
| 119 | + max_seq_length=args.calibration_seq_length, |
| 120 | + use_kv_cache=args.use_kv_cache, |
| 121 | + generate_full_logits=args.generate_full_logits, |
| 122 | + enable_dynamic_shape=args.enable_dynamic_shape, |
| 123 | + ) |
| 124 | + |
| 125 | + |
| 126 | + |
| 127 | +def eval_llama( |
| 128 | + model_name: str, |
| 129 | + args: argparse.Namespace, |
| 130 | +) -> None: |
| 131 | + # Generate the eval wrapper |
| 132 | + eval_wrapper = gen_eval_wrapper(model_name, args) |
| 133 | + |
| 134 | + # Needed for loading mmlu dataset. |
| 135 | + # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files |
| 136 | + if args.tasks and "mmlu" in args.tasks: |
| 137 | + import datasets |
| 138 | + |
| 139 | + datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True |
| 140 | + # Evaluate the model |
| 141 | + with torch.no_grad(): |
| 142 | + eval_results = simple_evaluate( |
| 143 | + model=eval_wrapper, |
| 144 | + tasks=args.tasks, |
| 145 | + num_fewshot=args.num_fewshot, |
| 146 | + limit=args.limit, |
| 147 | + ) |
| 148 | + |
| 149 | + for task, res in eval_results["results"].items(): |
| 150 | + print(f"{task}: {res}") |
| 151 | + |
| 152 | + |
| 153 | +def main() -> None: |
| 154 | + seed = 42 |
| 155 | + torch.manual_seed(seed) |
| 156 | + modelname = "llama2" |
| 157 | + parser = build_args_parser() |
| 158 | + args = parser.parse_args() |
| 159 | + # Overrides this arg, because evaluation requires full logits. |
| 160 | + args.generate_full_logits = True |
| 161 | + |
| 162 | + args.max_seq_len = args.max_seq_length |
| 163 | + args.calibration_seq_length = args.max_seq_length |
| 164 | + |
| 165 | + # Prefill mode |
| 166 | + args.use_kv_cache = False |
| 167 | + args.prefill_ar_len = args.max_seq_length |
| 168 | + |
| 169 | + args.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| 170 | + |
| 171 | + eval_llama(modelname, args) |
| 172 | + |
| 173 | + |
| 174 | +if __name__ == "__main__": |
| 175 | + main() # pragma: no cover |
0 commit comments