Added quantization for evaluation script #11822

Open

wants to merge 1 commit into base: main
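
In short, this change lets eval_llama_qnn.py evaluate a PTQ-quantized model: when args.ptq is set (the script currently hard-codes "8a8w"), the model is exported with torch.export, annotated by the Qualcomm quantizer, calibrated on a fixed prompt, converted with PT2E, and only then wrapped for lm_eval. A condensed sketch of that path, using the same helpers the diff imports — the quantize_for_eval name is illustrative, and the prompt, observer, and per-channel settings simply mirror what the script hard-codes:

import torch

from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
from executorch.examples.qualcomm.utils import make_quantizer


def quantize_for_eval(model, inputs, tokenizer, args, use_i64_token):
    # Build the Qualcomm PT2E quantizer for the requested dtype, e.g. "8a8w".
    quantizer = make_quantizer(
        quant_dtype=getattr(QuantDtype, f"use_{args.ptq}"),
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )

    # Export to a GraphModule and insert observers at the annotated ops.
    with torch.no_grad():
        model = torch.export.export(model, inputs, strict=True).module()
    model = prepare_pt2e(model, quantizer)

    # Calibrate the observers by running the prompt through the prepared model.
    calibrate(
        inputs,
        "Once upon a time",
        model,
        tokenizer=tokenizer,
        ar_len=args.prefill_ar_len,
        max_seq_len=args.max_seq_len,
        kv_updater=None,
        use_i64_token=use_i64_token,
    )

    # Replace observers with quant/dequant ops; this is the model that gets
    # wrapped in WrappedLlamaModel / GraphModuleEvalWrapper for evaluation.
    return convert_pt2e(model)

For the 16a4w_block dtype, the script additionally builds a per-conv block_size_map before prepare_pt2e, as shown in the diff below.
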
3 changes: 3 additions & 0 deletions examples/qualcomm/oss_scripts/llama/TARGETS
@@ -49,6 +49,9 @@ python_binary(
name = "eval_llama_qnn",
srcs = ["eval_llama_qnn.py"],
main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
preload_deps = [
"//executorch/extension/llm/custom_ops:model_sharding_py",
],
deps = [
":llama_lib",
"//executorch/examples/models/llama:eval_library",
120 changes: 107 additions & 13 deletions examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
@@ -4,9 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import argparse
import copy
import json
import torch
from functools import partial

from lm_eval.evaluator import simple_evaluate

from typing import List, Optional, Tuple

@@ -26,32 +31,53 @@

from pytorch_tokenizers import get_tokenizer

from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate

from executorch.examples.qualcomm.utils import make_quantizer

from executorch.examples.models.llama.source_transformation.quantize import (
get_quant_embedding_transform,
)

from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
from executorch.backends.qualcomm.quantizer.custom_annotation import (
annotate_linear_16a8w_in_affine_layer,
annotate_matmul_16a8w,
)


import logging
sys.setrecursionlimit(4096)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
logging.getLogger().setLevel(logging.INFO)


class WrappedLlamaModel(nn.Module):
def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
def __init__(self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"):
super(WrappedLlamaModel, self).__init__()
self.model = model
self.max_seq_len = max_seq_len
self.use_kv_cache = use_kv_cache
self.device = device
self.atten_mask = atten_mask

def forward(
self,
tokens: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
*args,
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
# Pad input if necessary, since LlamaModel requires static shape
if tokens.shape[1] != self.max_seq_len:
tokens = torch.nn.functional.pad(
tokens, (self.max_seq_len - tokens.shape[1], 0)
tokens, (0, self.max_seq_len - tokens.shape[1])
)
atten_mask = (
self.model.get_example_inputs(self.use_kv_cache)[1]
.to(device=self.device)
.to(dtype=torch.bfloat16)
)
return self.model.forward(tokens, atten_mask, input_pos, *args)
return self.model.forward(tokens, self.atten_mask)


def gen_eval_wrapper(model_name, args):
@@ -119,14 +145,73 @@ def permute(w, heads):
layer.feed_forward.prepare_feedfoward_conv()

model.to(dtype=torch.bfloat16)
model.to(args.device)
model.to(device=args.device)

tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
tokens = tokens.to(device=args.device)
atten_mask = atten_mask.to(device=args.device)
atten_mask = atten_mask.to(dtype=torch.bfloat16)
inputs = (tokens, atten_mask)

if args.embedding_quantize:
model = get_quant_embedding_transform(
embedding_quantize=args.embedding_quantize
)(model)

model = convert_linear_to_conv2d(model)

wrapped_model = WrappedLlamaModel(
model, args.use_kv_cache, args.max_seq_length, args.device
if args.ptq:
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

custom_annotations = (annotate_matmul_16a8w,)
if args.llama_model == "stories110m":
custom_annotations = custom_annotations + (
annotate_linear_16a8w_in_affine_layer,
)
quantizer = make_quantizer(
quant_dtype=quant_dtype,
per_channel_conv=True,
per_channel_linear=True,
act_observer=MinMaxObserver,
)
quantizer.add_custom_quant_annotations(custom_annotations)

model.has_quant_io = True

with torch.no_grad():
model = torch.export.export(
model, inputs, strict=True
).module()
if quant_dtype == QuantDtype.use_16a4w_block:
conv_nodes = [
n for n in model.graph.nodes if "conv" in n.name
]
block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
quantizer.set_block_size_map(block_size_map)

model = prepare_pt2e(model, quantizer)

logging.info("Quantizing the model...")

calibrate(
inputs,
'Once upon a time',
model,
tokenizer=tokenizer,
ar_len=args.prefill_ar_len,
max_seq_len=args.max_seq_len,
kv_updater=None,
use_i64_token=use_i64_token,
)

model = convert_pt2e(model)

model = WrappedLlamaModel(
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
)

return GraphModuleEvalWrapper(
model=wrapped_model,
model=model,
tokenizer=tokenizer,
max_seq_length=args.calibration_seq_length,
use_kv_cache=args.use_kv_cache,
@@ -167,6 +252,7 @@ def main() -> None:
modelname = "llama2"
parser = build_args_parser()
args = parser.parse_args()
args.llama_model = "llama3_2"
# Overrides this arg, because evaluation requires full logits.
args.generate_full_logits = True

@@ -177,7 +263,15 @@
args.use_kv_cache = False
args.prefill_ar_len = args.max_seq_length

# Use fewer samples for faster evaluation
args.limit = 0.1
# args.samples = {'wikitext': list(range(1))}

args.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(args.device)

args.ptq = '8a8w'


eval_llama(modelname, args)
