 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import sys
 import argparse
 import copy
 import json
+import torch
+from functools import partial
+
+from lm_eval.evaluator import simple_evaluate

 from typing import List, Optional, Tuple

...

 from pytorch_tokenizers import get_tokenizer

+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+
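+# Raise the recursion limit (tracing/exporting the full Llama graph can exceed the
+# default) and emit INFO-level progress logs.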
+import logging
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+

 class WrappedLlamaModel(nn.Module):
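+    # The wrapper caches a precomputed static attention mask so forward() only
+    # needs the token tensor.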
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask

     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)


 def gen_eval_wrapper(model_name, args):
@@ -119,14 +141,69 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()

     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
+
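+    # Build static example inputs on the eval device; the attention mask is cast
+    # to bf16 to match the model weights.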
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
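+    # Optionally apply weight-only quantization to the embedding table.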
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
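+    # Replace nn.Linear layers with equivalent conv2d modules for the Qualcomm backend.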
+    model = convert_linear_to_conv2d(model)
+
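+    # PT2E post-training quantization: export the model, insert observers with
+    # prepare_pt2e, run calibration, then convert_pt2e folds the observed ranges
+    # into quantize/dequantize ops.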
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = ()
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(model, inputs, strict=True).module()
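+            # Blockwise 16a4w quantization needs an explicit block size for every
+            # conv weight; 64-element blocks along the input-channel dim are used here.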
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+            logging.info("Quantizing the model...")
+
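+            # A single calibration pass over a sample prompt lets the observers
+            # record activation ranges before conversion.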
+            calibrate(
+                inputs,
+                "Once upon a time",
+                model,
+                tokenizer=tokenizer,
+                ar_len=args.prefill_ar_len,
+                max_seq_len=args.max_seq_length,
+                kv_updater=None,
+                use_i64_token=use_i64_token,
+            )

-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+        model = convert_pt2e(model)
+
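+    # Wrap the (possibly quantized) module so the eval harness can drive it with
+    # token tensors alone; the static attention mask is supplied by the wrapper.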
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )

     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
@@ -177,7 +254,15 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length

+    # Evaluate on fewer samples for a faster run; lm-eval treats a limit below 1
+    # as a fraction of the dataset.
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
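+    # Create new tensors on the eval device by default.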
+    torch.set_default_device(args.device)
+
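+    # Quantization scheme for this run: 16-bit activations, 4-bit weights.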
+    args.ptq = "16a4w"
+

     eval_llama(modelname, args)
