Commit 057558f

Added evaluation script for Qualcomm LlamaModel
Differential Revision: D76634688
Pull Request resolved: #11663
1 parent a23ed73 commit 057558f

2 files changed: 186 additions, 0 deletions

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -45,6 +45,17 @@ python_binary(
     ],
 )
 
+python_binary(
+    name = "eval_llama_qnn",
+    srcs = ["eval_llama_qnn.py"],
+    main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
+    deps = [
+        ":llama_lib",
+        "//executorch/examples/models/llama:eval_library",
+        "fbsource//third-party/pypi/lm-eval:lm-eval",
+    ],
+)
+
 runtime.command_alias(
     name = "llama_qnn",
     env = {
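
For reference, the new target would typically be invoked through Buck along the following lines. The label follows the convention of the deps above, and the flag names are assumptions inferred from the attributes the script reads (args.checkpoint, args.params, args.tokenizer_path, args.tasks, args.limit); they are not taken from the commit itself:

buck2 run //executorch/examples/qualcomm/oss_scripts/llama:eval_llama_qnn -- \
    --checkpoint <consolidated.00.pth> \
    --params <params.json> \
    --tokenizer_path <tokenizer.model> \
    --tasks wikitext \
    --limit 1
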
examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import json
import logging
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from executorch.examples.models.llama.eval_llama_lib import (
    build_args_parser,
    GraphModuleEvalWrapper,
)
from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
    LlamaModel,
    ModelArgs,
)
from lm_eval.evaluator import simple_evaluate
from pytorch_tokenizers import get_tokenizer


class WrappedLlamaModel(nn.Module):
    """Wraps the static-shape LlamaModel so it can be called with (tokens, input_pos)."""

    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
        super().__init__()
        self.model = model
        self.max_seq_len = max_seq_len
        self.use_kv_cache = use_kv_cache
        self.device = device

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        *args,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        # Pad the input if necessary, since LlamaModel requires a static shape.
        if tokens.shape[1] != self.max_seq_len:
            tokens = torch.nn.functional.pad(
                tokens, (self.max_seq_len - tokens.shape[1], 0)
            )
        atten_mask = (
            self.model.get_example_inputs(self.use_kv_cache)[1]
            .to(device=self.device)
            .to(dtype=torch.bfloat16)
        )
        return self.model.forward(tokens, atten_mask, input_pos, *args)


def gen_eval_wrapper(model_name, args):
    tokenizer = get_tokenizer(args.tokenizer_path)
    with open(args.params) as f:
        kv_config = ModelArgs(**json.load(f))
    # TODO: support batch inputs if necessary
    kv_config.max_batch_size = 1
    kv_config.max_seq_len = args.max_seq_length
    kv_config.use_kv_cache = True

    prefill_config = copy.copy(kv_config)
    prefill_config.max_seq_len = args.max_seq_length
    prefill_config.use_kv_cache = (
        False if args.max_seq_length == args.prefill_ar_len else True
    )
    config = prefill_config
    use_i64_token = args.embedding_quantize is not None
    model = LlamaModel(
        config,
        ar_len=args.prefill_ar_len,
        output_new_cache_only=True,
        output_cache=False,
        use_i64_token=use_i64_token,
    )
    state_dict = torch.load(
        args.checkpoint, weights_only=True, map_location=args.device, mmap=True
    )
    # Some checkpoints wrap the weights under a top-level "model" key.
    if "model" in state_dict:
        state_dict = state_dict["model"]

    # Convert the weights to the HuggingFace layout to improve RoPE performance
    # on the HTP backend.
    def permute(w, heads):
        dim_0 = w.size(0)
        dim_1 = w.size(1)
        return (
            w.view(heads, dim_0 // heads // 2, 2, dim_1)
            .transpose(1, 2)
            .reshape(dim_0, dim_1)
        )

    n_heads = model.n_heads
    n_kv_heads = model.n_kv_heads
    n_layers = model.n_layers

    for layer_i in range(n_layers):
        state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute(
            state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads
        )
        state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute(
            state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads
        )

    model.load_state_dict(
        state_dict,
        strict=True,
        assign=True,
    )

    for layer in model.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()
        if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
            layer.feed_forward.prepare_feedfoward_conv()

    model.to(dtype=torch.bfloat16)
    model.to(args.device)

    wrapped_model = WrappedLlamaModel(
        model, args.use_kv_cache, args.max_seq_length, args.device
    )

    return GraphModuleEvalWrapper(
        model=wrapped_model,
        tokenizer=tokenizer,
        max_seq_length=args.calibration_seq_length,
        use_kv_cache=args.use_kv_cache,
        generate_full_logits=args.generate_full_logits,
        enable_dynamic_shape=args.enable_dynamic_shape,
    )


def eval_llama(
    model_name: str,
    args: argparse.Namespace,
) -> None:
    # Generate the eval wrapper
    eval_wrapper = gen_eval_wrapper(model_name, args)

    # Needed for loading the mmlu dataset.
    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
    if args.tasks and "mmlu" in args.tasks:
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

    # Evaluate the model
    with torch.no_grad():
        eval_results = simple_evaluate(
            model=eval_wrapper,
            tasks=args.tasks,
            num_fewshot=args.num_fewshot,
            limit=args.limit,
        )

    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")


def main() -> None:
    seed = 42
    torch.manual_seed(seed)
    modelname = "llama2"
    parser = build_args_parser()
    args = parser.parse_args()
    # Override this arg, because evaluation requires full logits.
    args.generate_full_logits = True

    args.max_seq_len = args.max_seq_length
    args.calibration_seq_length = args.max_seq_length

    # Prefill mode
    args.use_kv_cache = False
    args.prefill_ar_len = args.max_seq_length

    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    eval_llama(modelname, args)


if __name__ == "__main__":
    main()  # pragma: no cover
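
As an aside, the permute helper in gen_eval_wrapper is the usual Meta-to-HuggingFace RoPE weight conversion: within each head it regroups interleaved rotary pairs into a half-split layout. A minimal standalone sketch of the same reshape follows; the tensor and sizes are made up purely for illustration and are not part of the commit:

import torch

def permute(w, heads):
    # Same reshape as in eval_llama_qnn.py: regroup rotary pairs from the
    # interleaved (Meta) ordering to the half-split (HuggingFace) ordering.
    dim_0, dim_1 = w.size(0), w.size(1)
    return (
        w.view(heads, dim_0 // heads // 2, 2, dim_1)
        .transpose(1, 2)
        .reshape(dim_0, dim_1)
    )

# Hypothetical sizes: 4 heads, head_dim 8, model dim 32.
wq = torch.arange(32 * 32, dtype=torch.float32).reshape(32, 32)
wq_hf = permute(wq, heads=4)
assert wq_hf.shape == wq.shape  # the layout changes, the shape does not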
