Commit 4fba0c7

rohansjoshi authored and facebook-github-bot committed
Added quantization for evaluation script
Summary: Added quantization to the evaluation script. Quantization causes a deterioration in accuracy on the wikitext task:

| Model Name | max_seq_len | ptq | word_perplexity |
|------------|-------------|-----|-----------------|
| Llama 3.2-1B Instruct | 128 | 16a4w | 5821003.055178451 |
| Llama 3.2-1B Instruct | 128 | 16a4w_block | 5396240.078572427 |
| Llama 3.2-1B Instruct | 128 | 8a8w | 533154.970440251 |

Differential Revision: D76837572
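For reference on the magnitudes above: lm_eval reports wikitext word_perplexity as the exponential of the total negative log-likelihood normalized by word count, so even modest per-word likelihood drops compound quickly. A minimal sketch of that relationship, with purely illustrative numbers (none of these values come from this diff):

```python
import math

# Purely illustrative per-chunk log-likelihoods (in nats) and word counts,
# standing in for what a rolling-loglikelihood pass over wikitext produces.
log_likelihoods = [-1234.5, -987.6]
word_counts = [512, 430]

# word_perplexity = exp(-(sum of log-likelihoods) / total word count)
word_perplexity = math.exp(-sum(log_likelihoods) / sum(word_counts))
print(f"word_perplexity = {word_perplexity:.2f}")
```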
1 parent 496cb05 commit 4fba0c7

File tree

2 files changed: +101 −13 lines changed


examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,9 @@ python_binary(
     name = "eval_llama_qnn",
     srcs = ["eval_llama_qnn.py"],
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:model_sharding_py",
+    ],
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 98 additions & 13 deletions
@@ -4,9 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sys
 import argparse
 import copy
 import json
+import torch
+from functools import partial
+
+from lm_eval.evaluator import simple_evaluate
 
 from typing import List, Optional, Tuple
 
@@ -26,32 +31,49 @@
 
 from pytorch_tokenizers import get_tokenizer
 
+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+
+
+import logging
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+
 
 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask
 
     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)
 
 
 def gen_eval_wrapper(model_name, args):
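Note on the padding change above: `torch.nn.functional.pad` takes `(left, right)` padding for the last dimension, so the fix moves the zero padding from before the tokens to after them, keeping the real tokens aligned with positions 0..n-1. A small standalone illustration (toy tensor, not the model's actual inputs):

```python
import torch
import torch.nn.functional as F

tokens = torch.tensor([[11, 12, 13]])  # shape (1, 3)
max_seq_len = 5
pad_len = max_seq_len - tokens.shape[1]

# Previous behavior: zeros prepended, shifting real tokens to the right.
left_padded = F.pad(tokens, (pad_len, 0))   # tensor([[ 0,  0, 11, 12, 13]])

# New behavior: zeros appended, real tokens stay at positions 0..2.
right_padded = F.pad(tokens, (0, pad_len))  # tensor([[11, 12, 13,  0,  0]])

print(left_padded)
print(right_padded)
```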
@@ -119,14 +141,69 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()
 
     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
+
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
+    model = convert_linear_to_conv2d(model)
+
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = ()
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(
+                model, inputs, strict=True
+            ).module()
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [
+                    n for n in model.graph.nodes if "conv" in n.name
+                ]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+        logging.info("Quantizing the model...")
+
+        calibrate(
+            inputs,
+            'Once upon a time',
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
 
-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+        model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )
 
     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
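The quantization path added to gen_eval_wrapper follows the usual PT2E post-training flow: export the model, attach the QNN quantizer via prepare_pt2e, run calibration data through the prepared module, then convert_pt2e. A condensed sketch of that flow with a toy module standing in for the Llama model (assumes the ExecuTorch Qualcomm utilities used above are installed; this is not the script's exact setup):

```python
import torch
from torch import nn

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import make_quantizer
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyConvNet(nn.Module):
    """Stand-in for the Llama model, for illustration only."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


model = TinyConvNet().eval()
example_inputs = (torch.randn(1, 4, 8, 8),)

# 1. Export to a GraphModule with static shapes.
exported = torch.export.export(model, example_inputs, strict=True).module()

# 2. Build the QNN quantizer (same knobs as in the script) and insert observers.
quantizer = make_quantizer(
    quant_dtype=QuantDtype.use_16a4w,
    per_channel_conv=True,
    per_channel_linear=True,
    act_observer=MinMaxObserver,
)
prepared = prepare_pt2e(exported, quantizer)

# 3. Calibrate: run representative inputs through the prepared module.
with torch.no_grad():
    prepared(*example_inputs)

# 4. Replace observers with quantize/dequantize ops.
quantized = convert_pt2e(prepared)
```

In the script itself, calibration runs the calibrate() helper over real tokenized prompts rather than random tensors, and the 16a4w_block mode additionally sets a per-conv block_size_map before prepare_pt2e.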
@@ -177,7 +254,15 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
+    # To do fewer samples for faster evaluation
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
+    args.ptq = '16a4w'
+
 
     eval_llama(modelname, args)
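On the hard-coded args.limit = 0.1 above: in lm_eval, a limit below 1 is treated as a fraction of each task's examples, so this evaluates roughly 10% of wikitext for speed. A hedged sketch of the equivalent direct call (lm_eval's stock Hugging Face wrapper is used here purely for illustration rather than the quantized QNN eval wrapper built by this script, and the result key naming follows lm_eval 0.4.x, which may differ between versions):

```python
from lm_eval.evaluator import simple_evaluate

# Illustrative only: evaluate ~10% of wikitext with lm_eval's HF model wrapper.
results = simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-3.2-1B-Instruct",
    tasks=["wikitext"],
    limit=0.1,
)
print(results["results"]["wikitext"]["word_perplexity,none"])
```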
