 # LICENSE file in the root directory of this source tree.

 import argparse
-import logging
 import copy
 import json
-import torch
-from lm_eval.evaluator import simple_evaluate

 from typing import List, Optional, Tuple


 from executorch.examples.models.llama.eval_llama_lib import (
     build_args_parser,
-    GraphModuleEvalWrapper
+    GraphModuleEvalWrapper,
 )

-from pytorch_tokenizers import get_tokenizer
-
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
-    LlamaModel,
-    ModelArgs,
+    LlamaModel,
+    ModelArgs,
 )
+from lm_eval.evaluator import simple_evaluate
+
+from pytorch_tokenizers import get_tokenizer


 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device='cuda'):
+    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device

-    def forward(self,
+    def forward(
+        self,
         tokens: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(tokens, (self.max_seq_len - tokens.shape[1],0))
-        atten_mask = self.model.get_example_inputs(self.use_kv_cache)[1].to(device=self.device).to(dtype=torch.bfloat16)
+            tokens = torch.nn.functional.pad(
+                tokens, (self.max_seq_len - tokens.shape[1], 0)
+            )
+        atten_mask = (
+            self.model.get_example_inputs(self.use_kv_cache)[1]
+            .to(device=self.device)
+            .to(dtype=torch.bfloat16)
+        )
         return self.model.forward(tokens, atten_mask, input_pos, *args)


-
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
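Aside: the reflowed padding call above keeps the original behavior; when the prompt is shorter than the model's static sequence length, it is left-padded with zeros along the last dimension so the shape stays fixed. A minimal standalone sketch, with a made-up max_seq_len of 8 and token values chosen only for illustration:

    import torch
    import torch.nn.functional as F

    max_seq_len = 8                        # hypothetical static length
    tokens = torch.tensor([[11, 12, 13]])  # shape (1, 3)
    # Pad (max_seq_len - current_len) zeros on the left of the last dimension.
    padded = F.pad(tokens, (max_seq_len - tokens.shape[1], 0))
    print(padded)        # tensor([[ 0,  0,  0,  0,  0, 11, 12, 13]])
    print(padded.shape)  # torch.Size([1, 8])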
@@ -66,7 +70,13 @@ def gen_eval_wrapper(model_name, args):
     )
     config = prefill_config
     use_i64_token = args.embedding_quantize is not None
-    model = LlamaModel(config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=False, use_i64_token=use_i64_token)
+    model = LlamaModel(
+        config,
+        ar_len=args.prefill_ar_len,
+        output_new_cache_only=True,
+        output_cache=False,
+        use_i64_token=use_i64_token,
+    )
     state_dict = torch.load(
         args.checkpoint, weights_only=True, map_location=args.device, mmap=True
     )
@@ -111,7 +121,9 @@ def permute(w, heads):
     model.to(dtype=torch.bfloat16)
     model.to(args.device)

-    wrapped_model = WrappedLlamaModel(model, args.use_kv_cache, args.max_seq_length, args.device)
+    wrapped_model = WrappedLlamaModel(
+        model, args.use_kv_cache, args.max_seq_length, args.device
+    )

     return GraphModuleEvalWrapper(
         model=wrapped_model,
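Note: the reflowed WrappedLlamaModel(...) call above passes its arguments positionally, so they bind to the __init__ signature from the first hunk as model, use_kv_cache, max_seq_len, device. A sketch with hypothetical values (model stands in for the LlamaModel instance built earlier in gen_eval_wrapper; main() forces args.use_kv_cache to False):

    wrapped_model = WrappedLlamaModel(
        model,   # the static LlamaModel loaded above
        False,   # use_kv_cache  (args.use_kv_cache)
        512,     # max_seq_len   (args.max_seq_length)
        "cpu",   # device        (args.device)
    )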
@@ -123,7 +135,6 @@ def permute(w, heads):
     )


-
 def eval_llama(
     model_name: str,
     args: argparse.Namespace,
@@ -166,7 +177,7 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length

-    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    args.device = "cuda" if torch.cuda.is_available() else "cpu"

     eval_llama(modelname, args)