
Commit 7b39a0c

guangy10 and Guang Yang authored
Fixed linter (#11742)
### Summary

Fixed linter error.

### Test plan

CI

Co-authored-by: Guang Yang <guangyang@fb.com>
1 parent 9eb8d01 commit 7b39a0c

File tree

1 file changed: 28 additions and 17 deletions


examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 28 additions & 17 deletions
@@ -5,11 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
-import logging
 import copy
 import json
-import torch
-from lm_eval.evaluator import simple_evaluate

 from typing import List, Optional, Tuple


@@ -18,38 +15,45 @@

 from executorch.examples.models.llama.eval_llama_lib import (
     build_args_parser,
-    GraphModuleEvalWrapper
+    GraphModuleEvalWrapper,
 )

-from pytorch_tokenizers import get_tokenizer
-
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
-    LlamaModel,
-    ModelArgs,
+    LlamaModel,
+    ModelArgs,
 )
+from lm_eval.evaluator import simple_evaluate
+
+from pytorch_tokenizers import get_tokenizer


 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device='cuda'):
+    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device

-    def forward(self,
+    def forward(
+        self,
         tokens: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
-            tokens = torch.nn.functional.pad(tokens, (self.max_seq_len - tokens.shape[1],0))
-        atten_mask = self.model.get_example_inputs(self.use_kv_cache)[1].to(device=self.device).to(dtype=torch.bfloat16)
+            tokens = torch.nn.functional.pad(
+                tokens, (self.max_seq_len - tokens.shape[1], 0)
+            )
+        atten_mask = (
+            self.model.get_example_inputs(self.use_kv_cache)[1]
+            .to(device=self.device)
+            .to(dtype=torch.bfloat16)
+        )
         return self.model.forward(tokens, atten_mask, input_pos, *args)


-
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
@@ -66,7 +70,13 @@ def gen_eval_wrapper(model_name, args):
     )
     config = prefill_config
     use_i64_token = args.embedding_quantize is not None
-    model = LlamaModel(config, ar_len=args.prefill_ar_len, output_new_cache_only=True, output_cache=False, use_i64_token=use_i64_token)
+    model = LlamaModel(
+        config,
+        ar_len=args.prefill_ar_len,
+        output_new_cache_only=True,
+        output_cache=False,
+        use_i64_token=use_i64_token,
+    )
     state_dict = torch.load(
         args.checkpoint, weights_only=True, map_location=args.device, mmap=True
     )
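For context, the unchanged `torch.load` call visible in this hunk combines three standard PyTorch arguments. A hedged sketch of the same pattern; the checkpoint path and device below are placeholders, not values from this script:

```python
import torch

checkpoint_path = "/tmp/llama_checkpoint.pth"  # hypothetical path, for illustration only
device = "cuda" if torch.cuda.is_available() else "cpu"

state_dict = torch.load(
    checkpoint_path,
    weights_only=True,    # restrict unpickling to tensors and plain containers
    map_location=device,  # place loaded tensors on the chosen device
    mmap=True,            # memory-map the checkpoint instead of reading it all into RAM
)
```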
@@ -111,7 +121,9 @@ def permute(w, heads):
     model.to(dtype=torch.bfloat16)
     model.to(args.device)

-    wrapped_model = WrappedLlamaModel(model, args.use_kv_cache, args.max_seq_length, args.device)
+    wrapped_model = WrappedLlamaModel(
+        model, args.use_kv_cache, args.max_seq_length, args.device
+    )

     return GraphModuleEvalWrapper(
         model=wrapped_model,
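The reformatted `WrappedLlamaModel(...)` call feeds the `GraphModuleEvalWrapper` returned here, and this script imports `simple_evaluate` from lm_eval, which is how such a wrapper is normally consumed. A sketch of that usage, assuming the wrapper behaves as an lm_eval-compatible model; the task list and `limit` are placeholders, not the exact arguments used by `eval_llama`:

```python
from lm_eval.evaluator import simple_evaluate

def run_eval(eval_wrapper, tasks=("wikitext",), limit=None):
    # eval_wrapper: an lm_eval-compatible wrapper, e.g. the
    # GraphModuleEvalWrapper constructed above (assumption for this sketch).
    results = simple_evaluate(
        model=eval_wrapper,
        tasks=list(tasks),
        limit=limit,  # optionally cap the number of examples per task
    )
    return results["results"]
```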
@@ -123,7 +135,6 @@ def permute(w, heads):
     )


-
 def eval_llama(
     model_name: str,
     args: argparse.Namespace,
@@ -166,7 +177,7 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length

-    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    args.device = "cuda" if torch.cuda.is_available() else "cpu"

     eval_llama(modelname, args)

0 commit comments