From d3228aa2d6af8b1069b16d479c1a39597751eae4 Mon Sep 17 00:00:00 2001
From: Rohan Joshi
Date: Fri, 20 Jun 2025 14:41:30 -0700
Subject: [PATCH] Added quantization for evaluation script (#11822)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/11822

Added quantization to the evaluation script. Quantization causes a deterioration in accuracy.

On the wikitext task:

| Model Name | max_seq_len | ptq | word_perplexity |
|----------|----------|----------|-----------|
| Llama 3.2-1B Instruct | 128 | 16a4w | 5821003.055178451 |
| Llama 3.2-1B Instruct | 128 | 16a4w_block | 5396240.078572427 |
| Llama 3.2-1B Instruct | 128 | 8a8w | 533154.970440251 |

Reviewed By: cccclai

Differential Revision: D76837572
---
 examples/qualcomm/oss_scripts/llama/TARGETS |   3 +
 .../oss_scripts/llama/eval_llama_qnn.py     | 120 ++++++++++++++++--
 2 files changed, 110 insertions(+), 13 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama/TARGETS b/examples/qualcomm/oss_scripts/llama/TARGETS
index aee00c44c76..9c5dd1ceaf9 100644
--- a/examples/qualcomm/oss_scripts/llama/TARGETS
+++ b/examples/qualcomm/oss_scripts/llama/TARGETS
@@ -49,6 +49,9 @@ python_binary(
     name = "eval_llama_qnn",
     srcs = ["eval_llama_qnn.py"],
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:model_sharding_py",
+    ],
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",
diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
index bb864a07429..1a8cbbb3de3 100644
--- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
+++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
@@ -4,9 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import sys
 import argparse
 import copy
 import json
+import torch
+from functools import partial
+
+from lm_eval.evaluator import simple_evaluate
 
 from typing import List, Optional, Tuple
 
@@ -26,32 +31,53 @@
 from pytorch_tokenizers import get_tokenizer
 
+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+
+import logging
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logging.getLogger().setLevel(logging.INFO)
+
 
 class WrappedLlamaModel(nn.Module):
 
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask
 
     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)
 
 
 def gen_eval_wrapper(model_name, args):
@@ -119,14 +145,73 @@ def permute(w, heads):
 
         layer.feed_forward.prepare_feedfoward_conv()
 
     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
+
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
+    model = convert_linear_to_conv2d(model)
 
-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(
+                model, inputs, strict=True
+            ).module()
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [
+                    n for n in model.graph.nodes if "conv" in n.name
+                ]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+            logging.info("Quantizing the model...")
+
+            calibrate(
+                inputs,
+                'Once upon a time',
+                model,
+                tokenizer=tokenizer,
+                ar_len=args.prefill_ar_len,
+                max_seq_len=args.max_seq_len,
+                kv_updater=None,
+                use_i64_token=use_i64_token,
+            )
+
+            model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )
 
     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
@@ -167,6 +252,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
 
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
@@ -177,7 +263,15 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
+    # To do fewer samples for faster evaluation
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
+    args.ptq = '8a8w'
+
     eval_llama(modelname, args)
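
For reference, the PTQ path that `gen_eval_wrapper()` now follows (export, `prepare_pt2e`, calibration, `convert_pt2e`) can be condensed to the sketch below. This is a minimal, hypothetical sketch rather than the script itself: `quantize_for_eval`, the default `use_8a8w` dtype, and the single forward-pass calibration are assumptions that stand in for the script's argument plumbing and its `calibrate()` helper.

```python
# Minimal sketch of the PTQ flow in the patch above. Assumptions: `model` is the
# static-shape Llama module and `inputs` is its (tokens, atten_mask) example pair;
# `quantize_for_eval` is a hypothetical helper, not part of the patch.
import torch

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import make_quantizer
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_for_eval(model, inputs, quant_dtype=QuantDtype.use_8a8w):
    # Same quantizer settings the patch passes to make_quantizer().
    quantizer = make_quantizer(
        quant_dtype=quant_dtype,
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )

    with torch.no_grad():
        # Export to a GraphModule, insert observers, run calibration data
        # through it, then fold observers into quantize/dequantize ops.
        module = torch.export.export(model, inputs, strict=True).module()
        module = prepare_pt2e(module, quantizer)
        module(*inputs)  # the script runs its calibrate() helper here instead
        module = convert_pt2e(module)
    return module
```

The converted module is what the script then wraps in `WrappedLlamaModel` and hands to `GraphModuleEvalWrapper`; the `16a4w_block` mode additionally sets a per-conv `block_size_map` on the quantizer before `prepare_pt2e`, as shown in the diff.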