From 650a85069c24ae288bdc638a5c91c6601e2b59ce Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Wed, 7 May 2025 07:59:42 +0000 Subject: [PATCH 01/11] QNN Backend: Integrate custom op package TMANOpPackage to efficiently support GPTQ and BitNet models - Integrate OSS runner into LlamaDemo - Fix several tman_linear & tman_bitnet_linear accuracy issues - Fix BitNet - Add BitNet support - Fix tiling parameters - Optimize performance with DMA & Tiling - Merge SHAs; Support torch.split_with_sizes; Workaround fix for custom op registration bug of multiple graphs compilation - Integrate into LlamaDemo - Add supports for symmetric quantization - Clean debug code - Fix uint16 correctness issue under certain inputs - Support inference with fp16 and TMANOpPackage - Define custom_op, annotators and NodeVisitor in Python side - Add custom op package TMANOpPackage --- backends/qualcomm/__init__.py | 1 + backends/qualcomm/_passes/layout_transform.py | 2 + backends/qualcomm/builders/__init__.py | 2 + backends/qualcomm/builders/custom_ops.py | 158 +++ backends/qualcomm/builders/op_tman_linear.py | 319 +++++ backends/qualcomm/builders/qnn_constants.py | 25 + backends/qualcomm/builders/utils.py | 203 ++- backends/qualcomm/partition/common_defs.py | 2 + backends/qualcomm/quantizer/annotators.py | 27 + .../qualcomm/quantizer/custom_annotation.py | 9 + backends/qualcomm/runtime/QnnManager.cpp | 82 +- .../runtime/backends/QnnBackendCommon.cpp | 59 + .../runtime/backends/QnnBackendFactory.cpp | 8 +- .../runtime/backends/QnnBackendFactory.h | 3 +- .../qualcomm/runtime/op_packages/.gitignore | 1 + .../op_packages/TMANOpPackage/Makefile | 342 +++++ .../TMANOpPackage/config/TMANOpPackageHtp.xml | 110 ++ .../TMANOpPackage/include/hvx_funcs.h | 752 +++++++++++ .../src/TMANOpPackageInterface.cpp | 288 ++++ .../TMANOpPackage/src/fp_extend.cpp | 110 ++ .../TMANOpPackage/src/fp_trunc.cpp | 134 ++ .../TMANOpPackage/src/ops/TMANFinalize.cpp | 102 ++ .../TMANOpPackage/src/ops/TMANLinear.cpp | 182 +++ 
.../TMANOpPackage/src/ops/TMANPrecompute.cpp | 115 ++ backends/qualcomm/tests/test_qnn_manager.py | 27 + backends/qualcomm/tests/test_tman_linear.py | 41 + backends/qualcomm/utils/utils.py | 217 +++ .../app/src/main/res/layout/activity_main.xml | 2 +- .../qualcomm/oss_scripts/bitnet/bitnet.py | 1186 +++++++++++++++++ .../oss_scripts/bitnet/model/__init__.py | 0 .../bitnet/model/configuration_bitnet.py | 147 ++ .../oss_scripts/bitnet/model/static_bitnet.py | 574 ++++++++ .../llama/convert_gptq_weights_to_llama.py | 214 +++ .../llama/convert_hf_weights_to_llama.py | 173 +++ examples/qualcomm/oss_scripts/llama/llama.py | 70 +- .../oss_scripts/llama/model/static_llama.py | 183 +++ .../oss_scripts/llama/qnn_llama_runner.cpp | 7 +- .../oss_scripts/llama/runner/CMakeLists.txt | 79 ++ .../oss_scripts/llama/runner/io_manager.cpp | 111 +- .../oss_scripts/llama/runner/io_manager.h | 17 +- .../oss_scripts/llama/runner/runner.cpp | 124 +- .../oss_scripts/llama/runner/runner.h | 35 +- extension/android/CMakeLists.txt | 19 +- extension/android/jni/jni_layer_llama.cpp | 20 + 44 files changed, 6166 insertions(+), 116 deletions(-) create mode 100644 backends/qualcomm/__init__.py create mode 100644 backends/qualcomm/builders/custom_ops.py create mode 100644 backends/qualcomm/builders/op_tman_linear.py create mode 100644 backends/qualcomm/runtime/op_packages/.gitignore create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/Makefile create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/config/TMANOpPackageHtp.xml create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/include/hvx_funcs.h create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/src/TMANOpPackageInterface.cpp create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_extend.cpp create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_trunc.cpp create mode 100644 
backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANFinalize.cpp create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp create mode 100644 backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANPrecompute.cpp create mode 100644 backends/qualcomm/tests/test_qnn_manager.py create mode 100644 backends/qualcomm/tests/test_tman_linear.py create mode 100644 examples/qualcomm/oss_scripts/bitnet/bitnet.py create mode 100644 examples/qualcomm/oss_scripts/bitnet/model/__init__.py create mode 100644 examples/qualcomm/oss_scripts/bitnet/model/configuration_bitnet.py create mode 100644 examples/qualcomm/oss_scripts/bitnet/model/static_bitnet.py create mode 100644 examples/qualcomm/oss_scripts/llama/convert_gptq_weights_to_llama.py create mode 100644 examples/qualcomm/oss_scripts/llama/convert_hf_weights_to_llama.py create mode 100644 examples/qualcomm/oss_scripts/llama/runner/CMakeLists.txt diff --git a/backends/qualcomm/__init__.py b/backends/qualcomm/__init__.py new file mode 100644 index 00000000000..b4065c99c12 --- /dev/null +++ b/backends/qualcomm/__init__.py @@ -0,0 +1 @@ +from .builders.custom_ops import * diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 19c5417f8f8..05668a0821d 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -100,6 +100,8 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.where.self, _operator.getitem, torch.ops.aten.scalar_tensor.default, + exir_ops.edge.tman.linear.default, + exir_ops.edge.tman.bitnet_linear.default, } layout_type = { diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 705d5d163cd..b218386844c 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -82,6 +82,7 @@ op_sub, op_sum_int_list, op_tanh, + op_tman_linear, op_to, op_topk, op_transpose, @@ -170,6 
+171,7 @@ op_sub, op_sum_int_list, op_tanh, + op_tman_linear, op_topk, op_to, op_transpose, diff --git a/backends/qualcomm/builders/custom_ops.py b/backends/qualcomm/builders/custom_ops.py new file mode 100644 index 00000000000..ac3f6b5c900 --- /dev/null +++ b/backends/qualcomm/builders/custom_ops.py @@ -0,0 +1,158 @@ +import torch +from .utils import unpack_weights + + +def _dequantize_weight( + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + g_idx: torch.Tensor, + wf_unsqueeze_zero: torch.Tensor, + wf_unsqueeze_neg_one: torch.Tensor, + bits: int, +) -> torch.Tensor: + """ + Based on dequantize_weights in gptqmodel/nn_modules/qlinear/__init__.py + """ + import torch as t + + num_itr = 1 # desc_act=False + assert(qweight.dtype == t.int32 and qzeros.dtype == t.int32) + pack_factor = 32 // bits + dequant_dtype = t.int16 if bits == 8 else t.int8 + maxq = 2 ** bits - 1 + + if bits in [2, 4, 8]: + zeros = t.bitwise_right_shift( + t.unsqueeze(qzeros, 2).expand(-1, -1, pack_factor), + wf_unsqueeze_zero # wf.unsqueeze(0), + ).to(dequant_dtype) + zeros = t.bitwise_and(zeros, maxq).reshape(scales.shape) + + weight = t.bitwise_and( + t.bitwise_right_shift( + t.unsqueeze(qweight, 1).expand(-1, pack_factor, -1), + wf_unsqueeze_neg_one # wf.unsqueeze(-1) + ).to(dequant_dtype), + maxq + ) + elif bits == 3: + zeros = qzeros.reshape(qzeros.shape[0], qzeros.shape[1] // 3, 3, 1).expand( + -1, -1, -1, 12 + ) + zeros = zeros >> wf_unsqueeze_zero # wf.unsqueeze(0) + zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4) + zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6) + zeros = zeros & 0x7 + zeros = t.cat( + [zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], + dim=2, + ).reshape(scales.shape) + + weight = qweight.reshape(qweight.shape[0] // 3, 3, 1, qweight.shape[1]).expand( + -1, -1, 12, -1 + ) + weight = (weight >> wf_unsqueeze_neg_one) & 0x7 # wf.unsqueeze(-1) + weight[:, 0, 10] = 
(weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4) + weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6) + weight = weight & 0x7 + weight = t.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + if num_itr == 1: + weights = scales[g_idx.long()] * (weight - zeros[g_idx.long()]) + else: + num_dim = g_idx.shape[0] // num_itr + weights = [] + for i in range(num_itr): + scale_i = scales[:, i * num_dim: (i + 1) * num_dim] + weight_i = weight[:, i * num_dim: (i + 1) * num_dim] + zeros_i = zeros[:, i * num_dim: (i + 1) * num_dim] + g_idx_i = g_idx[i * num_dim: (i + 1) * num_dim].long() + weights.append(scale_i[g_idx_i] * (weight_i - zeros_i[g_idx_i])) + weights = t.cat(weights, dim=1) + + return weights + + +@torch.library.custom_op("tman::linear", mutates_args=()) +def tman_linear( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + g_idx: torch.Tensor, + wf_unsqueeze_zero: torch.Tensor, + wf_unsqueeze_neg_one: torch.Tensor, + group_size: int, + bits: int, + symmetric: bool, + gptq_v2: bool, +) -> torch.Tensor: + out_features = qweight.shape[1] + out_shape = x.shape[:-1] + (out_features,) + x = x.reshape(-1, x.shape[-1]) + weights = _dequantize_weight( + qweight, + scales, + qzeros, + g_idx, + wf_unsqueeze_zero, + wf_unsqueeze_neg_one, + bits, + ).to(x.dtype) + out = torch.matmul(x, weights).reshape(out_shape) + return out.to(x.dtype) + + +@tman_linear.register_fake +def tman_linear_fake( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + g_idx: torch.Tensor, + wf_unsqueeze_zero: torch.Tensor, + wf_unsqueeze_neg_one: torch.Tensor, + group_size: int, + bits: int, + symmetric: bool, + gptq_v2: bool, +) -> torch.Tensor: + out_features = qweight.shape[1] + out_shape = x.shape[:-1] + (out_features,) + return x.new_zeros(out_shape) + + 
+@torch.library.custom_op("tman::bitnet_linear", mutates_args=()) +def tman_bitnet_linear( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> torch.Tensor: + # unpack weights + w = weight + w_quant = unpack_weights(w, dtype=x.dtype) + # activation_quant + num_bits = 8 + Qn = -(2 ** (num_bits - 1)) + Qp = 2 ** (num_bits - 1) - 1 + scale = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) + result = (x * scale).round().clamp(Qn, Qp) + input_quant, input_scale = result.to(torch.int8), scale + # linear + y = torch.nn.functional.linear(input_quant.to(x.dtype), w_quant) + y = y / input_scale * weight_scale + return y + + +@tman_bitnet_linear.register_fake +def tman_bitnet_linear_fake( + x: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, +) -> torch.Tensor: + VALUES_PER_ITEM = 4 + out_features = weight.shape[0] * VALUES_PER_ITEM + out_shape = x.shape[:-1] + (out_features,) + return x.new_zeros(out_shape) diff --git a/backends/qualcomm/builders/op_tman_linear.py b/backends/qualcomm/builders/op_tman_linear.py new file mode 100644 index 00000000000..97aa9f479d6 --- /dev/null +++ b/backends/qualcomm/builders/op_tman_linear.py @@ -0,0 +1,319 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_DATA +from executorch.backends.qualcomm.builders.utils import unpack_gptqv2, hvx_preprocess_weights, unpack_weights +import logging + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import ( + OpTMANLinear, + OpTMANPrecompute, + OpTMANFinalize, + OpConvert, + QNN_OP_PACKAGE_NAME_TMAN, + QNN_OP_PACKAGE_NAME_QTI_AISW, +) +from .utils import get_parameter + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def _get_c_size( + m: int, + bits: int, +) -> int: + # float32 + c_size = m * bits + return c_size * 4 + + +def _get_l_size( + k: int, + group_size: int, + need_dequant: bool = True, +) -> int: + LUT_G = 4 + LUT_SIZE = 16 + ACT_GROUP_SIZE = 256 + # float16 + x_size = k if need_dequant else 0 + # int16 + l_size = k // LUT_G * LUT_SIZE + # float32 + ls_size = 1 if (ACT_GROUP_SIZE == -1) else (k // ACT_GROUP_SIZE) + # float32 + lb_size = 1 if (group_size == 0) else (k // group_size) + return x_size * 2 + l_size * 2 + max(ls_size * 4, 128) + max(lb_size * 4, 128) + + +def _decide_tile_size( + dim_size: int, + total_size: int, + vtcm_size_in_mb: int = 8, + n_threads: int = 6, + divider: int = 2, +) -> int: + max_tile_size = vtcm_size_in_mb * 1024 * 1024 // n_threads + res = dim_size + success = False + for s in range(dim_size // divider, 0, -1): + chunk_size = s * divider + if dim_size % chunk_size != 0: + continue + res = chunk_size + if total_size // dim_size * res < max_tile_size: + success = True + break + if not success: + logger.warning(f"Can't find optimal tile size that is multiple of {divider} and fits in VTCM, use {res} as workaround") + return res + + +@register_node_visitor +class TMANLinear(NodeVisitor): + target = ["tman.linear.default", "tman.bitnet_linear.default"] + + def 
__init__(self, *args) -> None: + super().__init__(*args) + self.add_convert = True + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + if node.target.__name__ == "tman.linear.default": + qweight_node = node.args[1] + qweight_tensor = get_parameter(qweight_node, self.edge_program) + scales_node = node.args[2] + scales_tensor = get_parameter(scales_node, self.edge_program) + qzeros_node = node.args[3] + qzeros_tensor = get_parameter(qzeros_node, self.edge_program) + group_size = cast(int, node.args[7]) + bits = cast(int, node.args[8]) + symmetric = cast(bool, node.args[9]) + gptq_v2 = cast(bool, node.args[10]) + + qweight_repacked, scales_repacked, zeros_repacked, ref_bits, ref_group_size, ref_symmetric = unpack_gptqv2( + qweight_tensor.detach().numpy(), + scales_tensor.detach().numpy(), + qzeros_tensor.detach().numpy(), + gptq_v2, + ) + assert ref_bits == bits and ref_group_size == group_size and ref_symmetric == symmetric, ( + f"TMANLinear: bits/group_size/symmetric mismatch, {ref_bits}/{ref_group_size}/{ref_symmetric} != {bits}/{group_size}/{symmetric}" + ) + elif node.target.__name__ == "tman.bitnet_linear.default": + qweight_node = node.args[1] + qweight_tensor = get_parameter(qweight_node, self.edge_program) + scales_node = node.args[2] + scales_tensor = get_parameter(scales_node, self.edge_program) + group_size = 0 + bits = 2 + symmetric = True + + qweight_repacked = (unpack_weights(qweight_tensor.detach(), dtype=torch.int8) + 2).to(torch.uint8).numpy() + scales_repacked = scales_tensor.detach().numpy() + zeros_repacked = None + else: + raise NotImplementedError(f"Unsupported node target: {node.target.__name__}") + + 
output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + k = input_tensor.shape[-1] + m = output_tensor.shape[-1] + + zeros_repacked = zeros_repacked if not symmetric else None + vec_p = 128 + total_size = qweight_repacked.nbytes + max((scales_repacked.size + (zeros_repacked.size if zeros_repacked is not None else 0)) * np.dtype("float16").itemsize, 128) + tile_p = _decide_tile_size(m*bits, total_size, divider=bits*vec_p) + qweight_repacked, scales_repacked = hvx_preprocess_weights(qweight_repacked, scales_repacked, zeros_repacked, bits, tile_p=tile_p, vec_p=vec_p) + logger.info(f"TMANLinear: m={m}, k={k}, bits={bits}, tile_p={tile_p}, qweight({qweight_repacked.shape})") + + qweight_tensor_wrapper = self.define_tensor( + qweight_node, + node, + torch.from_numpy(qweight_repacked), + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + + scales_tensor_wrapper = self.define_tensor( + scales_node, + node, + torch.from_numpy(scales_repacked), + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + + # do not quantize scratch buffer + no_quant_encoding, no_quant_configs = ( + PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + {}, + ) + l_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_precompute", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8, + quant_encoding=no_quant_encoding, + quant_configs=no_quant_configs, + dims=torch.Size((1, _get_l_size(k, group_size, not self.add_convert))), + tensor=None, # Unused when is_fake_tensor is True + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + c_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_linear", + 
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8, + quant_encoding=no_quant_encoding, + quant_configs=no_quant_configs, + dims=torch.Size((1, _get_c_size(m, bits))), + tensor=None, # Unused when is_fake_tensor is True + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + + if self.add_convert: + intermediate_input_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_input_converted", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_16, + quant_encoding=no_quant_encoding, + quant_configs=no_quant_configs, + dims=input_tensor.size(), + tensor=None, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + intermediate_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_output_converted", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_16, + quant_encoding=no_quant_encoding, + quant_configs=no_quant_configs, + dims=output_tensor.size(), + tensor=None, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + input_convert_op = PyQnnWrapper.PyQnnOpWrapper( + node.name + "_input_convert", + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpConvert.op_name, + ) + input_convert_op.AddInputTensors([input_tensor_wrapper]) + input_convert_op.AddOutputTensors([intermediate_input_tensor_wrapper]) + + output_convert_op = PyQnnWrapper.PyQnnOpWrapper( + node.name + "_output_convert", + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpConvert.op_name, + ) + output_convert_op.AddInputTensors([intermediate_output_tensor_wrapper]) + output_convert_op.AddOutputTensors([output_tensor_wrapper]) + + input_tensor_wrapper = intermediate_input_tensor_wrapper + output_tensor_wrapper = intermediate_output_tensor_wrapper + + precompute_op = PyQnnWrapper.PyQnnOpWrapper( + node.name + "_precompute", + 
QNN_OP_PACKAGE_NAME_TMAN, + OpTMANPrecompute.op_name, + ) + precompute_op.AddInputTensors([input_tensor_wrapper]) + precompute_op.AddOutputTensors([l_tensor_wrapper]) + precompute_op.AddScalarParam( + OpTMANPrecompute.param_group_size, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(group_size)}, + ) + precompute_op.AddScalarParam( + OpTMANPrecompute.param_bits, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(bits)}, + ) + precompute_op.AddScalarParam( + OpTMANPrecompute.param_symmetric, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(symmetric)}, + ) + + linear_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_TMAN, + OpTMANLinear.op_name, + ) + linear_op.AddInputTensors([l_tensor_wrapper, qweight_tensor_wrapper, scales_tensor_wrapper]) + linear_op.AddOutputTensors([c_tensor_wrapper]) + linear_op.AddScalarParam( + OpTMANLinear.param_group_size, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(group_size)}, + ) + linear_op.AddScalarParam( + OpTMANLinear.param_bits, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(bits)}, + ) + linear_op.AddScalarParam( + OpTMANLinear.param_symmetric, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(symmetric)}, + ) + + finalize_op = PyQnnWrapper.PyQnnOpWrapper( + node.name + "_finalize", + QNN_OP_PACKAGE_NAME_TMAN, + OpTMANFinalize.op_name, + ) + finalize_op.AddInputTensors([c_tensor_wrapper]) + finalize_op.AddOutputTensors([output_tensor_wrapper]) + finalize_op.AddScalarParam( + OpTMANFinalize.param_group_size, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(group_size)}, + ) + finalize_op.AddScalarParam( + OpTMANFinalize.param_bits, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + {QCOM_DATA: np.int32(bits)}, + ) + finalize_op.AddScalarParam( + OpTMANFinalize.param_symmetric, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, + 
{QCOM_DATA: np.int32(symmetric)}, + ) + + if self.add_convert: + return [input_convert_op, precompute_op, linear_op, finalize_op, output_convert_op] + return [precompute_op, linear_op, finalize_op] diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 06e398f7c05..2aa8d43ff8b 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -8,6 +8,7 @@ from enum import IntEnum, unique QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw" +QNN_OP_PACKAGE_NAME_TMAN = "TMANOpPackage" # Below constants should be same as those in QNN headers. # Maybe someday we should expose these constants by pybind @@ -511,3 +512,27 @@ class OpTransposeConv2d: class OpUnpack: op_name: str = "UnPack" param_axis: str = "axis" + + +@dataclass(init=False, frozen=True) +class OpTMANLinear: + op_name: str = "TMANLinear" + param_group_size: str = "group_size" + param_bits: str = "bits" + param_symmetric: str = "symmetric" + + +@dataclass(init=False, frozen=True) +class OpTMANPrecompute: + op_name: str = "TMANPrecompute" + param_group_size: str = "group_size" + param_bits: str = "bits" + param_symmetric: str = "symmetric" + + +@dataclass(init=False, frozen=True) +class OpTMANFinalize: + op_name: str = "TMANFinalize" + param_group_size: str = "group_size" + param_bits: str = "bits" + param_symmetric: str = "symmetric" diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index c82ebaf1bb3..89bb3d4d9cd 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Dict, Optional +from typing import Dict, Optional, Tuple import torch from torch._export.utils import ( @@ -16,6 +16,8 @@ is_param, ) +import numpy as np + def is_parameter( node: torch.fx.Node, edge_program: torch.export.ExportedProgram @@ -125,3 +127,202 @@ def deduce_dtype( return quant_infos["dtype"] return tensor.dtype + + +def parse_gptqv2(qweight: np.ndarray, scales: np.ndarray, qzeros: np.ndarray) -> Tuple: + assert qweight.dtype == "int32" + assert qzeros.dtype == "int32" + + bits = 32 // (scales.shape[1] // qzeros.shape[1]) + K = qweight.shape[0] * (32 // bits) + M = qweight.shape[1] + group_size = K // scales.shape[0] + + return K, M, bits, group_size + + +def unpack_gptqv2(qweight: np.ndarray, scales: np.ndarray, qzeros: np.ndarray, gptq_v2: bool = True): + """ + Unpack GPTQv2 + Return T-MAC biased uint8 weight [0, 2 ** bits), fp16 scales, biased fp16 zeros, bits, group_size + """ + assert qweight.dtype == "int32" + assert qzeros.dtype == "int32" + # TODO: support other pack_dtypes + + K, M, bits, group_size = parse_gptqv2(qweight, scales, qzeros) + + # Detect symmetry + if bits == 2: + sym_zero = 0xaaaaaaaa + elif bits == 4: + sym_zero = 0x88888888 + else: + raise ValueError(f"Unsupported bits: {bits}") + symmetric = not (qzeros - np.uint32(sym_zero).astype(np.int32)).any() + + # Unpack qweight + qweights = [(qweight >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)] + w = np.stack(qweights, axis=1).reshape(K, M).T.astype("uint8") + + scales = scales.T + + # Unpack qzeros + zeros = [(qzeros >> bit_offset) & ((1 << bits) - 1) for bit_offset in range(0, 32, bits)] + zeros = np.stack(zeros, axis=-1).reshape(K // group_size, M).T.astype(scales.dtype) + if not gptq_v2: + # `zeros = zeros - 1` in AutoGPTQ + # Not in GPTQModel + zeros += 1 + zeros = (zeros - (2 ** (bits - 1))) + + return w, scales, zeros, bits, group_size, symmetric + + +def hvx_preprocess_weights( + w: np.ndarray, + scales: np.ndarray, + zeros: 
Optional[np.ndarray] = None, + bits: int = 4, + g: int = 4, + tile_p: int = 512, + tile_q: int = 64, + vec_p: int = 128, + vec_q: int = 4, + vec_c: int = 32, +) -> Tuple[np.ndarray, np.ndarray]: + + assert w.dtype == "uint8" + assert scales.dtype == "float16" or scales.dtype == "float32" or scales.dtype == "bfloat16" + if scales.dtype != "float16": + scales = scales.astype("float16") + zeros = zeros.astype("float16") if zeros is not None else None + # 4 = sizeof(int32/float) / sizeof(uint8) + assert vec_p // 4 == vec_c + M, K = w.shape + assert M >= vec_p, f"out features {M} should be larger than vec_p {vec_p}" + + P = M * bits + Q = K // g + + # (M, K, bits) + w = np.stack([(w >> ib) & 1 for ib in range(bits)], axis=-1) + # (M, K, bits) -> (M, bits, K) -> (M, bits, K) -> (M, bits, K // g, g) + w = w.transpose(0, 2, 1).reshape(M, bits, Q, g) + # (M, bits, K // g, g) -> (M, bits, Q) + w = sum([(w[:, :, :, ig] << ig) for ig in range(g)]) + # (M, bits, Q) -> (M // vec_p, vec_p, bits, Q) -> (M // vec_p, bits, vec_p, Q) -> (P // vec_p, vec_p, Q) + w = w.reshape(M // vec_p, vec_p, bits, Q).transpose(0, 2, 1, 3) + # Interleave even and odd vec_c of w_vec + # 0, 1 -> even bytes of w_vec -> c_vec_0, c_vec_2 -> c_bitsum_lo + # 2, 3 -> odd bytes of w_vec -> c_vec_1, c_vec_3 -> c_bitsum_hi + # w_vec = w0/w2/w0/w2......w1/w3/w1/w3 + # c_vec_0, c_vec_2 = w0/w0......w1/w1 + # c_vec_1, c_vec_3 = w2/w2......w3/w3 + w = w.reshape(P // vec_p, 2, 2, vec_c, Q).transpose(0, 2, 3, 1, 4) + w = w.reshape(P // tile_p, tile_p, Q // tile_q, tile_q).transpose(0, 2, 1, 3) + # 0 1 2 3 4 5 + w = w.reshape(P // tile_p, Q // tile_q, tile_p // vec_p, vec_p, tile_q // vec_q, vec_q).transpose(0, 1, 2, 4, 5, 3) + # Pack and interleave: q = 0 -> w_vec_lo_bo, q = 1 -> w_vec_lo_to, q = 2 -> w_vec_hi_bo, q = 3 -> w_vec_hi_to + # lo -> low 128 bytes, hi -> high 128 bytes, bo -> bot 4 bit in a byte, to -> top 4 bit in a byte + w = w.reshape(-1, vec_q, vec_p).reshape(-1, vec_q // 2, 2, vec_p).transpose(0, 1, 
3, 2) + w = sum([(w[:, :, :, n] << (n * g)) for n in range(2)]) + w = w.reshape(P // tile_p, Q // tile_q, tile_p // vec_p, tile_q // vec_q, vec_q // 2, vec_p) + # Reshape for easy tiling + w = np.ascontiguousarray(w).view(np.int32).reshape(P // tile_p, -1) + + if scales.size >= M: # GPTQ + group_size = K // scales.shape[1] + q_group_size = group_size // g + scales = scales.reshape(P // tile_p, tile_p // bits, Q // tile_q, tile_q // q_group_size).transpose(0, 2, 1, 3) + # 0 1 2 3 4 + scales = scales.reshape(P // tile_p, Q // tile_q, tile_p // bits // vec_p, vec_p, tile_q // q_group_size).transpose(0, 1, 2, 4, 3) + # s_vec = s0/s0......s1/s1......s2/s2......s3/s3 + # s_vec_lo_lo, s_vec_lo_hi = s0/s0......s1/s1 -> c_vec_0, c_vec_2 -> c_bitsum_lo + # no need for interleaving + if zeros is not None: + zeros = zeros.reshape(P // tile_p, tile_p // bits, Q // tile_q, tile_q // q_group_size).transpose(0, 2, 1, 3) + zeros = zeros.reshape(P // tile_p, Q // tile_q, tile_p // bits // vec_p, vec_p, tile_q // q_group_size).transpose(0, 1, 2, 4, 3) + # (c * ls + lb) * s + z * s * lb * 2 + # = (c * ls + lb + z * lb * 2) * s + # = (c * ls + (z * 2 + 1) * lb) * s + zeros = zeros * 2 + 1 + scales = np.stack([scales, zeros], axis=-2) + scales = scales.view(np.int32).reshape(P // tile_p, -1) + else: # BitNet + scales = scales.view(np.uint16).reshape(1, -1) + # [ERROR] [Qnn ExecuTorch]: QnnDsp Dma execution failed on the skel side. result = 1100 transport error = 0 + # Padding to vec_p + # TODO: verify if the padding is needed + if scales.nbytes < vec_p: + scales = np.resize(scales, (1, vec_p // np.dtype("int16").itemsize)) + return w, scales + + +def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """ + Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value. 
+ + Parameters: + ----------- + packed : torch.Tensor + A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value). + dtype : torch.dtype + The dtype of the returned Tensor + Returns: + -------- + torch.Tensor + A tensor of unpacked weights, where each value is converted from its packed 2-bit representation. + + Example: + -------- + packed = torch.tensor([[0b10100001, 0b00011000], + [0b10010000, 0b00001010]], dtype=torch.uint8) + + # Unpack the values + unpacked = unpack_weights(packed) + + # Resulting unpacked tensor + print(unpacked) + # Output: tensor([[ 0, -1], + [-1, 1], + [-1, 1], + [-1, 1], + [ 1, 0], + [ 0, -1], + [ 1, -1], + [ 1, -1]]) + + Explanation of the example: + --------------------------- + Let's take the first value for example 0b10100001, we will only focus on the first column, + because every element is unpacked across the first dimension + - First 2 bits: `01` → 0 at [0][0] + - Second 2 bits: `00` → -1 at [0][2] + - Third 2 bits: `10` → 1 at [0][4] + - Fourth 2 bits: `10` → 1 at [0][6] + the second value of the same row (0b10010000) will give the values for [0][1], [0][3], [0][5], [0][7] + + We subtract 1 because during the packing process, it's easier to work with values like 0, 1, and 2. To make this possible, + we add 1 to the original ternary weights (which are typically -1, 0, and 1) when packing them. When unpacking, we reverse + this by subtracting 1 to restore the original ternary values.
+ """ + VALUES_PER_ITEM = 4 + packed_shape = packed.shape + + if len(packed_shape) == 1: + original_row_dim = packed_shape[0] * VALUES_PER_ITEM + unpacked_shape = (original_row_dim,) + else: + original_row_dim = packed_shape[0] * VALUES_PER_ITEM + unpacked_shape = (original_row_dim, *packed_shape[1:]) + + unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8) + + for i in range(VALUES_PER_ITEM): + start = i * packed_shape[0] + end = start + packed_shape[0] + mask = 3 << (2 * i) + unpacked[start:end] = (packed & mask) >> (2 * i) + + return unpacked.to(dtype) - 1 diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 6326f4d1210..ea87b0a1cba 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -28,4 +28,6 @@ allow_list_operator = [ _operator.getitem, + exir_ops.edge.tman.linear.default, + exir_ops.edge.tman.bitnet_linear.default, ] diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 469a801feeb..fd08c8647db 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -1224,3 +1224,30 @@ def annotate_zeros(node: Node, quantization_config: QuantizationConfig) -> None: output_qspec=quantization_config.output_activation, _annotated=True, ) + + +@register_annotator([torch.ops.aten.split_with_sizes.default]) +def annotate_split_with_sizes(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_single_out(node, quantization_config) + + +try: + from executorch.backends.qualcomm.builders.custom_ops import tman_linear, tman_bitnet_linear + @register_annotator([torch.ops.tman.linear.default]) + def annotate_tman_linear(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + # We can use 
single_in_single_out since we don't want to quantize qweight and scales input + annotate_single_in_single_out(node, quantization_config) + + + @register_annotator([torch.ops.tman.bitnet_linear.default]) + def annotate_tman_bitnet_linear(node: Node, quantization_config: QuantizationConfig) -> None: + if _is_annotated([node]): + return + # We can use single_in_single_out since we don't want to quantize qweight and scales input + annotate_single_in_single_out(node, quantization_config) +except ImportError: + pass diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index bda91609f1c..86ce0783814 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from typing import Sequence +import operator import torch from executorch.backends.qualcomm.quantizer.annotators import ( @@ -229,6 +230,14 @@ def annotate_matmul_input1(node: Node): node, quantization_config=quantization_config_8a4w_per_channel ) break + elif node.target in [ + torch.ops.tman.linear.default, + torch.ops.tman.bitnet_linear.default, + torch.ops.aten.split_with_sizes.default, + operator.getitem, + ]: + # TODO: tman::linear currently does not support 8a + break elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]: break else: diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 13718b0891a..d14f0cfa8eb 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -24,6 +24,48 @@ namespace executorch { namespace backends { namespace qnn { +namespace { +// TODO: [ERROR] [Qnn ExecuTorch]: tcm_migration.cc:174:ERROR:Memory properties specified twice for operator ::TMANLinear +// The root cause of this error is that when QNN backend is 
freed, the memory properties of custom ops are not cleared, +// which will cause the error when the QNN backend is loaded and custom ops are registered again. +// This is a bug in QNN SDK, related to DEF_TENSOR_PROPERTIES / hnnx::register_tensor_properties. +// Workaround: prevent the QNN backend from being freed. +class GlobalBackend { +public: + QnnImplementation implementation_; + QnnLogger* logger_; + QnnBackend* backend_; + + static GlobalBackend& GetInstance() { + static GlobalBackend instance; + return instance; + } + ~GlobalBackend() { + if (backend_) { + delete backend_; + backend_ = nullptr; + } + if (logger_) { + delete logger_; + logger_ = nullptr; + } + } +private: + GlobalBackend() + : implementation_("libQnnHtp.so") { + implementation_.Load(nullptr); + logger_ = new QnnLogger( + implementation_, LoggingCallback, QnnExecuTorchLogLevel::kLogLevelWarn); + backend_ = new HtpBackend(implementation_, logger_); + Error error = backend_->Configure(); + if (error != Error::Ok) { + QNN_EXECUTORCH_LOG_ERROR( + "Failed to configure backend. 
Error code: %d", error); + } + }; +}; +} + using executorch::runtime::Error; bool CompareExportedInput( @@ -92,7 +134,8 @@ QnnManager::QnnManager( break; } } - qnn_loaded_backend_ = QnnImplementation(library_path); + // qnn_loaded_backend_ = QnnImplementation(library_path); + qnn_loaded_backend_ = GlobalBackend::GetInstance().implementation_; backend_params_ptr_ = std::make_unique(); qnn_dlc_manager_ = @@ -170,6 +213,19 @@ Error QnnManager::RegisterMem( void* custom_mem_base = shared_buffer_manager.GetCustomMemBase(data_ptr); if (custom_mem_base != nullptr) { + size_t tensor_bytes = 0; + for (const auto& info : shared_buffer_manager.GetCustomMemTensorInfoSet()) { + if (info.tensor_addr == data_ptr) { + tensor_bytes = info.tensor_bytes; + } + } + if (tensor_bytes != tensor_wrapper->GetBytes()) { + QNN_EXECUTORCH_LOG_WARN( + "Tensor %s size %u is not equal to custom mem size %zu\n", + tensor_wrapper->GetName().c_str(), + tensor_wrapper->GetBytes(), + tensor_bytes); + } return RegisterCustomMem(data_ptr, custom_mem_base, tensor_wrapper); } return RegisterIonMem(data_ptr, tensor_wrapper); @@ -279,8 +335,9 @@ Error QnnManager::RegisterCustomMem( Error QnnManager::Init() { ET_CHECK_OR_RETURN_ERROR( LoadQnnLibrary() == Error::Ok, Internal, "Fail to load Qnn library"); - logger_ = std::make_unique( - qnn_loaded_backend_, LoggingCallback, options_->log_level()); + // logger_ = std::make_unique( + // qnn_loaded_backend_, LoggingCallback, options_->log_level()); + logger_ = std::unique_ptr(GlobalBackend::GetInstance().logger_); if (backend_params_ptr_->backend_init_state_ == BackendInitializeState::UNINITIALIZED) { QNN_EXECUTORCH_LOG_INFO( @@ -292,7 +349,8 @@ Error QnnManager::Init() { logger_.get(), qnn_context_blob_, options_, - qnn_dlc_manager_.get()); + qnn_dlc_manager_.get(), + std::unique_ptr(GlobalBackend::GetInstance().backend_)); ET_CHECK_OR_RETURN_ERROR( backend_params_ptr_ != nullptr, Internal, @@ -301,10 +359,10 @@ Error QnnManager::Init() { 
backend_params_ptr_->qnn_backend_cache_ptr_->Configure() == Error::Ok, Internal, "Fail to configure Qnn backend cache"); - ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_backend_ptr_->Configure() == Error::Ok, - Internal, - "Fail to configure Qnn backend"); + // ET_CHECK_OR_RETURN_ERROR( + // backend_params_ptr_->qnn_backend_ptr_->Configure() == Error::Ok, + // Internal, + // "Fail to configure Qnn backend"); ET_CHECK_OR_RETURN_ERROR( backend_params_ptr_->qnn_device_ptr_->Configure() == Error::Ok, Internal, @@ -463,11 +521,17 @@ Error QnnManager::ProfileExecuteData( void QnnManager::Destroy() { QNN_EXECUTORCH_LOG_INFO("Destroy Qnn backend parameters"); + if (backend_params_ptr_->qnn_backend_ptr_ != nullptr) { + GlobalBackend::GetInstance().backend_ = backend_params_ptr_->qnn_backend_ptr_.release(); + } + if (logger_ != nullptr) { + GlobalBackend::GetInstance().logger_ = logger_.release(); + } backend_params_ptr_.reset(new BackendConfigParameters()); qnn_dlc_manager_->ResetBackendParams(); logger_.reset(); qnn_dlc_manager_->ResetLogger(); - qnn_loaded_backend_.TerminateAllBackends(); + // qnn_loaded_backend_.TerminateAllBackends(); qnn_dlc_manager_->TerminateAllBackends(); } diff --git a/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp b/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp index 310e38d1744..85ffcead41b 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCommon.cpp @@ -6,12 +6,30 @@ * LICENSE file in the root directory of this source tree. 
*/ #include +#include namespace executorch { namespace backends { namespace qnn { using executorch::runtime::Error; +namespace { +void split( + std::vector& splitString, + const std::string& tokenizedString, + const char separator) { + splitString.clear(); + std::istringstream tokenizedStringStream(tokenizedString); + while (!tokenizedStringStream.eof()) { + std::string value; + getline(tokenizedStringStream, value, separator); + if (!value.empty()) { + splitString.push_back(value); + } + } +} +} // namespace + QnnBackend::~QnnBackend() { const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; @@ -54,6 +72,47 @@ Error QnnBackend::Configure() { QNN_GET_ERROR_CODE(error)); return Error::Internal; } + + // TODO: Expose API to options in QnnManager later + std::string opPackagePaths = "/data/local/tmp/llama/libQnnTMANOpPackage.so:TMANOpPackageInterfaceProvider:HTP"; + if (const char* env_p = std::getenv("QNN_OP_PACKAGE_PATHS")) { + opPackagePaths = env_p; + } + std::vector m_opPackagePaths; + split(m_opPackagePaths, opPackagePaths, ','); + + const size_t pathIdx = 0; + const size_t interfaceProviderIdx = 1; + for (auto const& opPackagePath : m_opPackagePaths) { + std::vector opPackage; + split(opPackage, opPackagePath, ':'); + const char* target = nullptr; + const size_t targetIdx = 2; + if (opPackage.size() != 2 && opPackage.size() != 3) { + QNN_EXECUTORCH_LOG_ERROR( + "Malformed opPackageString provided: %s", opPackagePath.c_str()); + return Error::Internal; + } + if (opPackage.size() == 3) { + target = opPackage[targetIdx].c_str(); + } + error = qnn_interface.qnn_backend_register_op_package( + handle_, opPackage[pathIdx].c_str(), opPackage[interfaceProviderIdx].c_str(), target); + if (error != QNN_SUCCESS) { + QNN_EXECUTORCH_LOG_ERROR( + "Failed to register " + "op package %s for Backend " + "ID %u, error=%d", + opPackage[pathIdx].c_str(), + qnn_interface.GetBackendId(), + QNN_GET_ERROR_CODE(error)); + return 
Error::Internal; + } + QNN_EXECUTORCH_LOG_INFO( + "Registered Op Package: %s and interface provider: %s", + opPackage[pathIdx].c_str(), + opPackage[interfaceProviderIdx].c_str()); + } return Error::Ok; } diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp index 1f251aeaffa..c52ebf9f723 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp @@ -19,7 +19,8 @@ std::unique_ptr QnnBackendFactory::Create( QnnLogger* logger, const QnnExecuTorchContextBinary& qnn_context_blob, const QnnExecuTorchOptions* options, - QnnDlcManager* qnn_dlc_manager) { + QnnDlcManager* qnn_dlc_manager, + std::unique_ptr&& backend_ptr) { auto backend_params = std::make_unique(); switch (options->backend_options()->backend_type()) { @@ -55,8 +56,9 @@ std::unique_ptr QnnBackendFactory::Create( QNN_EXECUTORCH_LOG_INFO( "use_fold_relu in htp_options: %d", htp_options->use_fold_relu()); } - backend_params->qnn_backend_ptr_ = - std::make_unique(implementation, logger); + // backend_params->qnn_backend_ptr_ = + // std::make_unique(implementation, logger); + backend_params->qnn_backend_ptr_ = std::move(backend_ptr); backend_params->qnn_device_ptr_ = std::make_unique( implementation, logger, options->soc_info(), htp_options); diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.h b/backends/qualcomm/runtime/backends/QnnBackendFactory.h index 3d78a36b9f0..ae4c3562284 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendFactory.h +++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.h @@ -70,7 +70,8 @@ class QnnBackendFactory { QnnLogger* logger, const QnnExecuTorchContextBinary& qnn_context_blob, const QnnExecuTorchOptions* options, - QnnDlcManager* qnn_dlc_manager); + QnnDlcManager* qnn_dlc_manager, + std::unique_ptr&& backend_ptr); }; } // namespace qnn } // namespace backends diff --git 
a/backends/qualcomm/runtime/op_packages/.gitignore b/backends/qualcomm/runtime/op_packages/.gitignore new file mode 100644 index 00000000000..bdf319e79b0 --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/.gitignore @@ -0,0 +1 @@ +**/build/ diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/Makefile b/backends/qualcomm/runtime/op_packages/TMANOpPackage/Makefile new file mode 100644 index 00000000000..2293076492b --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/Makefile @@ -0,0 +1,342 @@ +# check all setup prerequisites if the command goal is not clean +ifneq ($(MAKECMDGOALS),clean) +ifndef QNN_INCLUDE +$(info "INFO: Qnn include not explicitly defined, attempting to use QNN_SDK_ROOT if it is valid") +QNN_INCLUDE := $(QNN_SDK_ROOT)/include/QNN +endif +ifeq ($(wildcard $(QNN_INCLUDE)),) +$(error "ERROR: QNN_INCLUDE path is not set. QNN include paths must be set to obtain BE headers necessary to compile the package") +endif +ifndef QNN_TARGET_LIB +$(info "INFO: Qnn target not explicitly defined, attempting to use QNN_SDK_ROOT if it is valid") +QNN_TARGET_LIB := $(QNN_SDK_ROOT)/lib/aarch64-android +endif +ifeq ($(wildcard $(QNN_TARGET_LIB)),) +ifeq ($(MAKECMDGOALS),htp_aarch64) +$(error "ERROR: QNN_TARGET_LIB is needed to compile package for aarch64") +else ifeq ($(MAKECMDGOALS),all) +$(info "WARNING:QNN_TARGET_LIB may need to be defined to compile packages") +endif +endif + +ifndef HEXAGON_SDK_ROOT +$(error "ERROR: HEXAGON_SDK_ROOT is not set. Hexagon-SDK path must be set to the latest hexagon-sdk-x.y.z") +endif + +ifeq ($(wildcard $(HEXAGON_SDK_ROOT)),) +$(error "ERROR: HEXAGON_SDK_ROOT is not set correctly. 
Please set HEXAGON_SDK_ROOT to latest hexagon-sdk-X.Y.Z path") +endif + +HEXAGON_SDK_BASE := $(dir $(HEXAGON_SDK_ROOT)) + +$(info "HEXAGON_SDK_ROOT is [${HEXAGON_SDK_ROOT}]") +# Users should note that the tools version may change between hexagon sdk versions +# Following combination of SDK and Tool version is supported +HEXAGON_SDK_ROOT_V68 := $(HEXAGON_SDK_ROOT) +HEXAGON_SDK_ROOT_V69 := $(HEXAGON_SDK_ROOT) +HEXAGON_SDK_ROOT_V73 := $(HEXAGON_SDK_ROOT) +HEXAGON_SDK_ROOT_V75 := $(HEXAGON_SDK_ROOT) +HEXAGON_SDK_ROOT_V79 := $(HEXAGON_SDK_ROOT) +HEXAGON_SDK_ROOT_X86 := $(HEXAGON_SDK_ROOT) +HEXAGON_TOOLS_VERSION_V68 := 8.4.09 +HEXAGON_TOOLS_VERSION_V69 := 8.5.03 +HEXAGON_TOOLS_VERSION_V73 := 8.6.02 +HEXAGON_TOOLS_VERSION_V75 := 8.8.06 +HEXAGON_TOOLS_VERSION_V79 := 8.8.06 +#Updated to point to latest sdk to match with libQnnHtp.so +HEXAGON_TOOLS_VERSION_X86 := 8.8.06 + +ifndef ANDROID_NDK_ROOT +ifeq ($(MAKECMDGOALS),htp_aarch64) +$(error "ERROR: ANDROID_NDK_ROOT is not set. Android NDK path must be set to compile package for aarch64") +else ifeq ($(MAKECMDGOALS),all) +$(info "WARNING: ANDROID_NDK_ROOT is not set. Android NDK path must be set to compile package for aarch64") +endif +endif + +ifndef PACKAGE_NAME +export +PACKAGE_NAME := $(notdir $(shell pwd)) +$(info "INFO: No package name defined. 
Using current directory name: $(PACKAGE_NAME) as the package name") +endif + +WORK := build +SRC_DIR := src +OP_SRC_DIR := src/ops +OP_INCLUDE_DIR := ./include +OP_INCLUDES := $(wildcard $(OP_INCLUDE_DIR)/*.h) # user defined if any op specific headers are needed, add -I to common flags +LIBRARY_NAME := libQnn$(PACKAGE_NAME).so +SUPPORTED_TARGETS = x86_64-linux-clang hexagon-v68 hexagon-v69 hexagon-v73 hexagon-v75 hexagon-v79 aarch64-android + + +COMMON_CXX_FLAGS = -std=c++17 -I$(QNN_INCLUDE) -I$(OP_INCLUDE_DIR) -fPIC -Wall -Wreorder -Wno-missing-braces -Wno-unused-function -Wno-unused-variable +COMMON_CXX_FLAGS += -Werror -Wno-format -Wno-unused-command-line-argument -fvisibility=default -stdlib=libc++ +COMMON_CXX_FLAGS += -DQNN_API="__attribute__((visibility(\"default\")))" -D__QAIC_HEADER_EXPORT="__attribute__((visibility(\"default\")))" + +X86_LIBNATIVE_RELEASE_DIR := $(HEXAGON_SDK_ROOT_X86)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_X86)/Tools + +# Ensure hexagon sdk tool version can be retrieved +ifeq ($(wildcard $(X86_LIBNATIVE_RELEASE_DIR)/.),) +$(error "Cannot retrieve hexagon tools from: $(X86_LIBNATIVE_RELEASE_DIR). \ + \ + Please check that hexagon tools version is correct. Expected: $(HEXAGON_TOOLS_VERSION_X86)") +endif + +#Check tools for hexagon_v79 are present. +ifeq ($(MAKECMDGOALS),htp_v79) +ifeq ($(wildcard $(HEXAGON_SDK_ROOT_V79)),) +$(error "ERROR: HEXAGON_SDK_ROOT_V79 is set incorrectly. Cannot retrieve $(HEXAGON_SDK_ROOT_V79)") +endif +endif + +#Check tools for hexagon_v75 are present. +ifeq ($(MAKECMDGOALS),htp_v75) +ifeq ($(wildcard $(HEXAGON_SDK_ROOT_V75)),) +$(error "ERROR: HEXAGON_SDK_ROOT_V75 is set incorrectly. Cannot retrieve $(HEXAGON_SDK_ROOT_V75)") +endif +endif + +#Check tools for hexagon_v68 are present. +ifeq ($(MAKECMDGOALS),htp_v68) +ifeq ($(wildcard $(HEXAGON_SDK_ROOT_V68)),) +$(error "ERROR: HEXAGON_SDK_ROOT_V68 is set incorrectly. 
Cannot retrieve $(HEXAGON_SDK_ROOT_V68)") +endif +endif + +ifeq ($(MAKECMDGOALS),htp_v69) +ifeq ($(wildcard $(HEXAGON_SDK_ROOT_V69)),) +$(error "ERROR: HEXAGON_SDK_ROOT_V69 is set incorrectly. Cannot retrieve $(HEXAGON_SDK_ROOT_V69)") +endif +endif + +ifeq ($(MAKECMDGOALS),htp_v73) +ifeq ($(wildcard $(HEXAGON_SDK_ROOT_V73)),) +$(error "ERROR: HEXAGON_SDK_ROOT_V73 is set incorrectly. Cannot retrieve $(HEXAGON_SDK_ROOT_V73)") +endif +endif + +endif +OP_SOURCES = $(wildcard $(OP_SRC_DIR)/*.cpp) +OTHER_SOURCES = $(wildcard $(SRC_DIR)/*.cpp) +HFILES = $(wildcard $(QNN_INCLUDE)/*.h) +HFILES += $(wildcard $(QNN_INCLUDE)/HTP/*.h) +HFILES += $(wildcard $(QNN_INCLUDE)/HTP/core/*.h) +HFILES += $(OP_INCLUDES) +OP_OBJS = $(patsubst $(SRC_DIR)/%,%,$(patsubst %.cpp,%.o,$(OP_SOURCES))) +OTHER_OBJS = $(patsubst $(SRC_DIR)/%,%,$(patsubst %.cpp,%.o,$(OTHER_SOURCES))) + +#======= Assembly ======== +OP_SOURCES_ASM_X86 += $(wildcard $(OP_SRC_DIR)/x86_asm/*.S) +OP_OBJS_ASM_X86 += $(subst /x86_asm/,/,$(patsubst $(SRC_DIR)/%,%,$(patsubst %.S,%.o,$(OP_SOURCES_ASM_X86)))) +OP_SOURCES_ASM_V68 += $(wildcard $(OP_SRC_DIR)/v68_asm/*.S) +OP_OBJS_ASM_V68 += $(subst /v68_asm/,/,$(patsubst $(SRC_DIR)/%,%,$(patsubst %.S,%.o,$(OP_SOURCES_ASM_V68)))) +OP_SOURCES_ASM_V69 += $(wildcard $(OP_SRC_DIR)/v69_asm/*.S) +OP_OBJS_ASM_V69 += $(subst /v69_asm/,/,$(patsubst $(SRC_DIR)/%,%,$(patsubst %.S,%.o,$(OP_SOURCES_ASM_V69)))) +OP_SOURCES_ASM_V73 += $(wildcard $(OP_SRC_DIR)/v73_asm/*.S) +OP_OBJS_ASM_V73 += $(subst /v73_asm/,/,$(patsubst $(SRC_DIR)/%,%,$(patsubst %.S,%.o,$(OP_SOURCES_ASM_V73)))) +OP_SOURCES_ASM_V75 += $(wildcard $(OP_SRC_DIR)/v75_asm/*.S) +OP_OBJS_ASM_V75 += $(subst /v75_asm/,/,$(patsubst $(SRC_DIR)/%,%,$(patsubst %.S,%.o,$(OP_SOURCES_ASM_V75)))) +OP_SOURCES_ASM_V79 += $(wildcard $(OP_SRC_DIR)/v79_asm/*.S) +OP_OBJS_ASM_V79 += $(subst /v79_asm/,/,$(patsubst $(SRC_DIR)/%,%,$(patsubst %.S,%.o,$(OP_SOURCES_ASM_V79)))) +OP_SOURCES_ASM_ANDROID += $(wildcard $(OP_SRC_DIR)/android_asm/*.S) 
+OP_OBJS_ASM_ANDROID += $(subst /android_asm/,/,$(patsubst $(SRC_DIR)/%,%,$(patsubst %.S,%.o,$(OP_SOURCES_ASM_ANDROID)))) + +all: htp_v68 htp_x86 htp_aarch64 + +#============================================================================================================ +# Setup compiler, compiler instructions and linker for x86 +X86_CXX ?= clang++-9 +# Checking if clang++-9 is present. If not switch to clang++ +ifeq ($(shell $(X86_CXX) -v 2>&1 | grep -c "clang version"), 0) + X86_CXX := clang++ +endif +X86_LDFLAGS:= -Wl,--whole-archive -L$(X86_LIBNATIVE_RELEASE_DIR)/libnative/lib -lnative -Wl,--no-whole-archive -lpthread +X86_C_FLAGS := -D__HVXDBL__ -I$(X86_LIBNATIVE_RELEASE_DIR)/libnative/include -ffast-math -DUSE_OS_LINUX +X86_CXX_FLAGS = $(COMMON_CXX_FLAGS) $(X86_C_FLAGS) -fomit-frame-pointer -Wno-invalid-offsetof +linux_objs = +#============================================================================================================ +# Setup compiler, compiler instructions and linker for hexagon +HEXAGON_CXX_FLAGS := $(COMMON_CXX_FLAGS) -mhvx -mhvx-length=128B -mhmx -DUSE_OS_QURT -O2 -Wno-reorder -DPREPARE_DISABLED + +HEXAGON_CXX_FLAGS_V68 := $(HEXAGON_CXX_FLAGS) -mv68 -I$(HEXAGON_SDK_ROOT_V68)/rtos/qurt/computev68/include/qurt -I$(HEXAGON_SDK_ROOT_V68)/rtos/qurt/computev68/include/posix -I$(HEXAGON_SDK_ROOT_V68)/incs -I$(HEXAGON_SDK_ROOT_V68)/incs/stddef +HEXAGON_CXX_FLAGS_V69 := $(HEXAGON_CXX_FLAGS) -mv69 -I$(HEXAGON_SDK_ROOT_V69)/rtos/qurt/computev69/include/qurt -I$(HEXAGON_SDK_ROOT_V69)/rtos/qurt/computev69/include/posix -I$(HEXAGON_SDK_ROOT_V69)/incs -I$(HEXAGON_SDK_ROOT_V69)/incs/stddef +HEXAGON_CXX_FLAGS_V73 := $(HEXAGON_CXX_FLAGS) -mv73 -I$(HEXAGON_SDK_ROOT_V73)/rtos/qurt/computev73/include/qurt -I$(HEXAGON_SDK_ROOT_V73)/rtos/qurt/computev73/include/posix -I$(HEXAGON_SDK_ROOT_V73)/incs -I$(HEXAGON_SDK_ROOT_V73)/incs/stddef +HEXAGON_CXX_FLAGS_V75 := $(HEXAGON_CXX_FLAGS) -mv75 -I$(HEXAGON_SDK_ROOT_V75)/rtos/qurt/computev75/include/qurt 
-I$(HEXAGON_SDK_ROOT_V75)/rtos/qurt/computev75/include/posix -I$(HEXAGON_SDK_ROOT_V75)/incs -I$(HEXAGON_SDK_ROOT_V75)/incs/stddef +HEXAGON_CXX_FLAGS_V79 := $(HEXAGON_CXX_FLAGS) -mv79 -I$(HEXAGON_SDK_ROOT_V79)/rtos/qurt/computev79/include/qurt -I$(HEXAGON_SDK_ROOT_V79)/rtos/qurt/computev79/include/posix -I$(HEXAGON_SDK_ROOT_V79)/incs -I$(HEXAGON_SDK_ROOT_V79)/incs/stddef + +HEXAGON_CXX_V68 := $(HEXAGON_SDK_ROOT_V68)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V68)/Tools/bin/hexagon-clang++ +HEXAGON_CXX_V69 := $(HEXAGON_SDK_ROOT_V69)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V69)/Tools/bin/hexagon-clang++ +HEXAGON_CXX_V73 := $(HEXAGON_SDK_ROOT_V73)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V73)/Tools/bin/hexagon-clang++ +HEXAGON_CXX_V75 := $(HEXAGON_SDK_ROOT_V75)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V75)/Tools/bin/hexagon-clang++ +HEXAGON_CXX_V79 := $(HEXAGON_SDK_ROOT_V79)/tools/HEXAGON_Tools/$(HEXAGON_TOOLS_VERSION_V79)/Tools/bin/hexagon-clang++ + +HEX_LDFLAGS = +hexagon_objs = +#============================================================================================================ +# Setup compiler, compiler instructions and linker for aarch64 +AARCH64_C__FLAGS = -D__HVXDBL__ -I$(X86_LIBNATIVE_RELEASE_DIR)/libnative/include -ffast-math -DUSE_OS_LINUX -DANDROID +AARCH64_CXX_FLAGS = $(COMMON_CXX_FLAGS) $(AARCH64_C__FLAGS) -fomit-frame-pointer -Wno-invalid-offsetof -Wno-unused-variable -Wno-unused-parameter -Wno-missing-braces -Wno-sign-compare -Wno-unused-private-field -Wno-unused-variable -Wno-ignored-qualifiers -Wno-missing-field-initializers +ARM_CLANG_OPTS =--target=aarch64-none-linux-android21 --sysroot=$(ANDROID_NDK_ROOT)/toolchains/llvm/prebuilt/linux-x86_64/sysroot -stdlib=libc++ -static-libstdc++ +AARCH64_CXX = $(ANDROID_NDK_ROOT)/toolchains/llvm/prebuilt/linux-x86_64/bin/clang++ $(ARM_CLANG_OPTS) +AARCH64_LDFLAGS = -L$(QNN_TARGET_LIB) -lQnnHtp -lQnnHtpPrepare +aarch64_objs = 
+#============================================================================================================ +# Setup targets and goals + +htp_x86: X86_BUILD + +htp_v68: HEXAGON_BUILD_V68 + +htp_v69: HEXAGON_BUILD_V69 + +htp_v73: HEXAGON_BUILD_V73 + +htp_v75: HEXAGON_BUILD_V75 + +htp_v79: HEXAGON_BUILD_V79 + +htp_aarch64: AARCH64_BUILD + +AARCH64_BUILD: $(WORK)/aarch64-android/$(LIBRARY_NAME) + +HEXAGON_BUILD_V68: $(WORK)/hexagon-v68/$(LIBRARY_NAME) + +HEXAGON_BUILD_V69: $(WORK)/hexagon-v69/$(LIBRARY_NAME) + +HEXAGON_BUILD_V73: $(WORK)/hexagon-v73/$(LIBRARY_NAME) + +HEXAGON_BUILD_V75: $(WORK)/hexagon-v75/$(LIBRARY_NAME) + +HEXAGON_BUILD_V79: $(WORK)/hexagon-v79/$(LIBRARY_NAME) + +X86_BUILD: $(WORK)/x86_64-linux-clang/$(LIBRARY_NAME) + + +define build_objs = +ifneq ($(filter $(2),$(SUPPORTED_TARGETS)),) +$(2)_objs += $(foreach x,$(1),$(WORK)/$(2)/$(x)) +else +$$(error "Unknown target option provided: $(2): Supported targets are: $(SUPPORTED_TARGETS)") +endif +endef + +$(eval $(call build_objs,$(OTHER_OBJS),x86_64-linux-clang)) +$(eval $(call build_objs,$(OP_OBJS),x86_64-linux-clang)) +$(eval $(call build_objs,$(OP_OBJS_ASM_X86),x86_64-linux-clang)) +$(eval $(call build_objs,$(OTHER_OBJS),hexagon-v68)) +$(eval $(call build_objs,$(OP_OBJS),hexagon-v68)) +$(eval $(call build_objs,$(OP_OBJS_ASM_V68),hexagon-v68)) +$(eval $(call build_objs,$(OTHER_OBJS),hexagon-v69)) +$(eval $(call build_objs,$(OP_OBJS),hexagon-v69)) +$(eval $(call build_objs,$(OP_OBJS_ASM_V69),hexagon-v69)) +$(eval $(call build_objs,$(OTHER_OBJS),hexagon-v73)) +$(eval $(call build_objs,$(OP_OBJS),hexagon-v73)) +$(eval $(call build_objs,$(OP_OBJS_ASM_V73),hexagon-v73)) +$(eval $(call build_objs,$(OTHER_OBJS),hexagon-v75)) +$(eval $(call build_objs,$(OP_OBJS),hexagon-v75)) +$(eval $(call build_objs,$(OP_OBJS_ASM_V75),hexagon-v75)) +$(eval $(call build_objs,$(OTHER_OBJS),hexagon-v79)) +$(eval $(call build_objs,$(OP_OBJS),hexagon-v79)) +$(eval $(call build_objs,$(OP_OBJS_ASM_V79),hexagon-v79)) +$(eval 
$(call build_objs,$(OTHER_OBJS),aarch64-android)) +$(eval $(call build_objs,$(OP_OBJS),aarch64-android)) +$(eval $(call build_objs,$(OP_OBJS_ASM_ANDROID),aarch64-android)) + +# x86 +$(WORK)/x86_64-linux-clang $(WORK)/hexagon-v68 $(WORK)/hexagon-v69 $(WORK)/hexagon-v73 $(WORK)/hexagon-v75 $(WORK)/hexagon-v79 $(WORK)/aarch64-android: + @mkdir -p $@/ops + +$(WORK)/x86_64-linux-clang/%.o: $(SRC_DIR)/%.cpp | $(WORK)/x86_64-linux-clang + $(X86_CXX) $(X86_CXX_FLAGS) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/x86_64-linux-clang/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/x86_64-linux-clang + $(X86_CXX) $(X86_CXX_FLAGS) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/x86_64-linux-clang/ops/%.o: $(OP_SRC_DIR)/x86_asm/%.S | $(WORK)/x86_64-linux-clang + $(X86_CXX) $(X86_CXX_FLAGS) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/x86_64-linux-clang/$(LIBRARY_NAME): $(x86_64-linux-clang_objs) | $(HFILES) + $(X86_CXX) -fPIC -std=c++17 -g -shared -o $@ $^ $(X86_LDFLAGS) + +# v68 +$(WORK)/hexagon-v68/%.o: $(SRC_DIR)/%.cpp | $(WORK)/hexagon-v68 + $(HEXAGON_CXX_V68) $(HEXAGON_CXX_FLAGS_V68) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v68/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/hexagon-v68 + $(HEXAGON_CXX_V68) $(HEXAGON_CXX_FLAGS_V68) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v68/ops/%.o: $(OP_SRC_DIR)/v68_asm/%.S | $(WORK)/hexagon-v68 + $(HEXAGON_CXX_V68) $(HEXAGON_CXX_FLAGS_V68) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v68/$(LIBRARY_NAME): $(hexagon-v68_objs) | $(HFILES) + $(HEXAGON_CXX_V68) -fPIC -std=c++17 -g -shared -o $@ $^ $(HEX_LDFLAGS) + +# v69 +$(WORK)/hexagon-v69/%.o: $(SRC_DIR)/%.cpp | $(WORK)/hexagon-v69 + $(HEXAGON_CXX_V69) $(HEXAGON_CXX_FLAGS_V69) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v69/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/hexagon-v69 + $(HEXAGON_CXX_V69) $(HEXAGON_CXX_FLAGS_V69) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ 
+ +$(WORK)/hexagon-v69/ops/%.o: $(OP_SRC_DIR)/v69_asm/%.S | $(WORK)/hexagon-v69 + $(HEXAGON_CXX_V69) $(HEXAGON_CXX_FLAGS_V69) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v69/$(LIBRARY_NAME): $(hexagon-v69_objs) | $(HFILES) + $(HEXAGON_CXX_V69) -fPIC -std=c++17 -g -shared -o $@ $^ $(HEX_LDFLAGS) + +# v73 +$(WORK)/hexagon-v73/%.o: $(SRC_DIR)/%.cpp | $(WORK)/hexagon-v73 + $(HEXAGON_CXX_V73) $(HEXAGON_CXX_FLAGS_V73) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v73/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/hexagon-v73 + $(HEXAGON_CXX_V73) $(HEXAGON_CXX_FLAGS_V73) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v73/ops/%.o: $(OP_SRC_DIR)/v73_asm/%.S | $(WORK)/hexagon-v73 + $(HEXAGON_CXX_V73) $(HEXAGON_CXX_FLAGS_V73) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v73/$(LIBRARY_NAME): $(hexagon-v73_objs) | $(HFILES) + $(HEXAGON_CXX_V73) -fPIC -std=c++17 -g -shared -o $@ $^ $(HEX_LDFLAGS) + +#v75 +$(WORK)/hexagon-v75/%.o: $(SRC_DIR)/%.cpp | $(WORK)/hexagon-v75 + $(HEXAGON_CXX_V75) $(HEXAGON_CXX_FLAGS_V75) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v75/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/hexagon-v75 + $(HEXAGON_CXX_V75) $(HEXAGON_CXX_FLAGS_V75) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v75/ops/%.o: $(OP_SRC_DIR)/v75_asm/%.S | $(WORK)/hexagon-v75 + $(HEXAGON_CXX_V75) $(HEXAGON_CXX_FLAGS_V75) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v75/$(LIBRARY_NAME): $(hexagon-v75_objs) | $(HFILES) + $(HEXAGON_CXX_V75) -fPIC -std=c++17 -g -shared -o $@ $^ $(HEX_LDFLAGS) + +#v79 +$(WORK)/hexagon-v79/%.o: $(SRC_DIR)/%.cpp | $(WORK)/hexagon-v79 + $(HEXAGON_CXX_V79) $(HEXAGON_CXX_FLAGS_V79) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v79/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/hexagon-v79 + $(HEXAGON_CXX_V79) $(HEXAGON_CXX_FLAGS_V79) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + 
+$(WORK)/hexagon-v79/ops/%.o: $(OP_SRC_DIR)/v79_asm/%.S | $(WORK)/hexagon-v79 + $(HEXAGON_CXX_V79) $(HEXAGON_CXX_FLAGS_V79) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/hexagon-v79/$(LIBRARY_NAME): $(hexagon-v79_objs) | $(HFILES) + $(HEXAGON_CXX_V79) -fPIC -std=c++17 -g -shared -o $@ $^ $(HEX_LDFLAGS) + +# aarch64 +$(WORK)/aarch64-android/%.o: $(SRC_DIR)/%.cpp | $(WORK)/aarch64-android + $(AARCH64_CXX) $(AARCH64_CXX_FLAGS) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/aarch64-android/ops/%.o: $(OP_SRC_DIR)/%.cpp | $(WORK)/aarch64-android + $(AARCH64_CXX) $(AARCH64_CXX_FLAGS) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/aarch64-android/ops/%.o: $(OP_SRC_DIR)/android_asm/%.S | $(WORK)/aarch64-android + $(AARCH64_CXX) $(AARCH64_CXX_FLAGS) -DTHIS_PKG_NAME=$(PACKAGE_NAME) -MMD -c $< -o $@ + +$(WORK)/aarch64-android/$(LIBRARY_NAME): $(aarch64-android_objs) | $(HFILES) + $(AARCH64_CXX) -fPIC -std=c++17 -g -shared -o $@ $^ $(AARCH64_LDFLAGS) + +clean: + -rm -rf $(WORK) + +.PHONY: all clean diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/config/TMANOpPackageHtp.xml b/backends/qualcomm/runtime/op_packages/TMANOpPackage/config/TMANOpPackageHtp.xml new file mode 100644 index 00000000000..ed69b7a34e0 --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/config/TMANOpPackageHtp.xml @@ -0,0 +1,110 @@ + + + + + TMANLinear + + Matrix multiplication through T-MAC paradigm + + + + in + + Input tensor + + true + QNN_DATATYPE_UFIXED_POINT_16 + + 4D + NHWC + + + + + qweight + + qweight matrix + + true + QNN_DATATYPE_INT_32 + + 4D + NHWC + + true + + + + scales + + scale factor + + true + QNN_DATATYPE_INT_32 + + 4D + NHWC + + true + + + + out + + Output tensor + + true + QNN_DATATYPE_UFIXED_POINT_16 + + 4D + NHWC + + + + + scratch_buffer + + Output tensor + + true + QNN_DATATYPE_INT_32 + + 4D + NHWC + + + + + group_size + true + QNN_DATATYPE_INT_32 + + SCALAR + + 64 + + + + bits + true + QNN_DATATYPE_INT_32 + 
+ SCALAR + + 2 + + + + symmetric + true + QNN_DATATYPE_INT_32 + + SCALAR + + 0 + + + HTP + + + diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/include/hvx_funcs.h b/backends/qualcomm/runtime/op_packages/TMANOpPackage/include/hvx_funcs.h new file mode 100644 index 00000000000..4816048d05f --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/include/hvx_funcs.h @@ -0,0 +1,752 @@ +#pragma once + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "QnnOpPackage.h" +#include "HTP/core/simple_reg.h" + +#include +#include + +#define UNUSED(x) (void)(x) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +static inline int32_t _fp32_to_bits(float x) +{ + union { + float f; + int32_t i; + } u; + u.f = x; + return u.i; +} + +static inline int16_t _fp16_to_bits(const __fp16 *x) +{ + union { + __fp16 f; + int16_t i; + } u; + u.f = *x; + return u.i; +} + +#ifndef _HVX_INTERNAL_H +#define _HVX_INTERNAL_H + +#define VLEN 128 +#define vmem(A) *((HVX_Vector *)(A)) +#define HVX_INLINE_ALWAYS inline __attribute__((unused,always_inline)) +static HVX_INLINE_ALWAYS void l2fetch(const void *p, uint32_t stride, + uint32_t width, uint32_t height, + uint32_t dir) +{ +#ifdef __hexagon__ + uint64_t control = HEXAGON_V64_CREATE_H(dir, stride, width, height); + __asm__ __volatile__ (" l2fetch(%0,%1) " : :"r"(p),"r"(control)); +#endif +} + +#endif /* _HVX_INTERNAL_H */ + +template +inline typename std::enable_if_t::value && std::is_same::value, int> +hvx_lut_ctor(int32_t GemmK, int32_t GemmN, const XType *x, LType *l, float *ls, float *lb) +{ + UNUSED(GemmN); + + const int32_t Q = GemmK / g; + + const int32_t q_act_group_size = (ActGroupSize < 0) ? (Q / -ActGroupSize) : (ActGroupSize / g); + const int32_t q_group_size = (GroupSize == 0) ? 
Q : (GroupSize / g); + + constexpr int32_t lut_size = 16; + constexpr float max_int16 = 32767.0f; + + constexpr int32_t VecQ = VLEN / sizeof(XType); + + const HVX_Vector zero_vec = Q6_V_vzero(); + const HVX_Vector ones_vec = Q6_Vh_vsplat_R(0x3C00); // 1.0f + const HVX_Vector abs_mask = Q6_Vh_vsplat_R(0x7FFF); + + // lut_bias is stored in fp16 if ZeroPoint is true to avoid conversion during hvx_tbl + using lb_t = typename std::conditional::type; + lb_t *lb_p = reinterpret_cast(lb); + + XType __attribute__((aligned(VLEN))) tmp_buf[VLEN / sizeof(XType)]; + + HVX_Vector lb_val_vec = zero_vec; + for (int32_t group_q = 0; group_q < Q; group_q += q_act_group_size) + { + // Compute LUT scales + HVX_Vector ls_val_vec = zero_vec; + for (int32_t q = 0; q < q_act_group_size; q += VecQ) + { + const XType *x_base = x + (group_q + q) * g; + + HVX_Vector x0 = vmem(x_base); + HVX_Vector x1 = vmem(x_base + VecQ); + HVX_Vector x2 = vmem(x_base + VecQ * 2); + HVX_Vector x3 = vmem(x_base + VecQ * 3); + + // Transpose (64, 4) -> (4, 64) + // 16-bit + HVX_VectorPair x01 = Q6_W_vdeal_VVR(x1, x0, -2); + HVX_VectorPair x23 = Q6_W_vdeal_VVR(x3, x2, -2); + // 32-bit + HVX_VectorPair x02 = Q6_W_vdeal_VVR(Q6_V_lo_W(x23), Q6_V_lo_W(x01), -2); + HVX_VectorPair x13 = Q6_W_vdeal_VVR(Q6_V_hi_W(x23), Q6_V_hi_W(x01), -2); + + // abs + // Vhf_vabs_Vhf works on simulator, but not on device (test on 8 gen 3) + HVX_Vector x0_abs = Q6_V_vand_VV(Q6_V_lo_W(x02), abs_mask); + HVX_Vector x1_abs = Q6_V_vand_VV(Q6_V_lo_W(x13), abs_mask); + HVX_Vector x2_abs = Q6_V_vand_VV(Q6_V_hi_W(x02), abs_mask); + HVX_Vector x3_abs = Q6_V_vand_VV(Q6_V_hi_W(x13), abs_mask); + + // sum + HVX_Vector x01_abs = Q6_Vqf16_vadd_VhfVhf(x0_abs, x1_abs); + HVX_Vector x23_abs = Q6_Vqf16_vadd_VhfVhf(x2_abs, x3_abs); + HVX_Vector sum_abs = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vqf16(x01_abs, x23_abs)); + + ls_val_vec = Q6_Vhf_vmax_VhfVhf(ls_val_vec, sum_abs); + } + + // self_max + for (int32_t i = VLEN / 2; i >= 2; i >>= 1) + { + 
ls_val_vec = Q6_Vhf_vmax_VhfVhf(ls_val_vec, Q6_V_vlalign_VVR(ls_val_vec, zero_vec, i)); + } + vmem(tmp_buf) = ls_val_vec; + float ls_val = (float)tmp_buf[VLEN / 2 - 1] / max_int16; + ls[group_q / q_act_group_size] = ls_val; + + float tls_val = ls_val ? 1.0f / ls_val : 0.0f; + HVX_Vector tls_val_qf32 = Q6_Vqf32_vadd_VsfVsf(Q6_V_vsplat_R(_fp32_to_bits(tls_val)), zero_vec); + + // Construct LUT + // qf16 is not enough for accumulation + for (int32_t q = 0; q < q_act_group_size; q += VecQ) + { + const XType *x_base = x + (group_q + q) * g; + + HVX_Vector x0 = vmem(x_base); + HVX_Vector x1 = vmem(x_base + VecQ); + HVX_Vector x2 = vmem(x_base + VecQ * 2); + HVX_Vector x3 = vmem(x_base + VecQ * 3); + + // Transpose (64, 4) -> (4, 64) + HVX_VectorPair x01 = Q6_W_vdeal_VVR(x1, x0, -2); + HVX_VectorPair x23 = Q6_W_vdeal_VVR(x3, x2, -2); + HVX_VectorPair x02 = Q6_W_vdeal_VVR(Q6_V_lo_W(x23), Q6_V_lo_W(x01), -2); + HVX_VectorPair x13 = Q6_W_vdeal_VVR(Q6_V_hi_W(x23), Q6_V_hi_W(x01), -2); + + // Instead of add zero, multiply by one is more accurate + HVX_VectorPair x0_qf32 = Q6_Wqf32_vmpy_VhfVhf(Q6_V_lo_W(x02), ones_vec); + HVX_VectorPair x1_qf32 = Q6_Wqf32_vmpy_VhfVhf(Q6_V_lo_W(x13), ones_vec); + HVX_VectorPair x2_qf32 = Q6_Wqf32_vmpy_VhfVhf(Q6_V_hi_W(x02), ones_vec); + HVX_VectorPair x3_qf32 = Q6_Wqf32_vmpy_VhfVhf(Q6_V_hi_W(x13), ones_vec); + + HVX_Vector l_tmp_lo_qf32[lut_size]; + HVX_Vector l_tmp_hi_qf32[lut_size]; +#pragma unroll + for (int32_t i = 1; i < lut_size; i += 2) + { + if (i & 0b0010) { + l_tmp_lo_qf32[i] = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(x0_qf32), Q6_V_lo_W(x1_qf32)); + l_tmp_hi_qf32[i] = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(x0_qf32), Q6_V_hi_W(x1_qf32)); + } else { + l_tmp_lo_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(x0_qf32), Q6_V_lo_W(x1_qf32)); + l_tmp_hi_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(x0_qf32), Q6_V_hi_W(x1_qf32)); + } + if (i & 0b0100) { + l_tmp_lo_qf32[i] = Q6_Vqf32_vadd_Vqf32Vqf32(l_tmp_lo_qf32[i], Q6_V_lo_W(x2_qf32)); + l_tmp_hi_qf32[i] 
= Q6_Vqf32_vadd_Vqf32Vqf32(l_tmp_hi_qf32[i], Q6_V_hi_W(x2_qf32)); + } else { + l_tmp_lo_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(l_tmp_lo_qf32[i], Q6_V_lo_W(x2_qf32)); + l_tmp_hi_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(l_tmp_hi_qf32[i], Q6_V_hi_W(x2_qf32)); + } + if (i & 0b1000) { + l_tmp_lo_qf32[i] = Q6_Vqf32_vadd_Vqf32Vqf32(l_tmp_lo_qf32[i], Q6_V_lo_W(x3_qf32)); + l_tmp_hi_qf32[i] = Q6_Vqf32_vadd_Vqf32Vqf32(l_tmp_hi_qf32[i], Q6_V_hi_W(x3_qf32)); + } else { + l_tmp_lo_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(l_tmp_lo_qf32[i], Q6_V_lo_W(x3_qf32)); + l_tmp_hi_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(l_tmp_hi_qf32[i], Q6_V_hi_W(x3_qf32)); + } + } + + // Mirror consolidation +#pragma unroll + for (int32_t i = 0; i < lut_size; i += 2) + { + // NOT the sign bit won't work + l_tmp_lo_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(zero_vec, l_tmp_lo_qf32[lut_size - 1 - i]); + l_tmp_hi_qf32[i] = Q6_Vqf32_vsub_Vqf32Vqf32(zero_vec, l_tmp_hi_qf32[lut_size - 1 - i]); + } + + lb_val_vec = Q6_Vqf32_vadd_Vqf32Vqf32(lb_val_vec, l_tmp_lo_qf32[lut_size - 1]); + lb_val_vec = Q6_Vqf32_vadd_Vqf32Vqf32(lb_val_vec, l_tmp_hi_qf32[lut_size - 1]); + + // Quant LUT + HVX_Vector l_tmp[lut_size]; +#pragma unroll + for (int32_t i = 0; i < lut_size; i += 1) + { + HVX_Vector l_tmp_lo = Q6_Vqf32_vmpy_Vqf32Vqf32(l_tmp_lo_qf32[i], tls_val_qf32); + HVX_Vector l_tmp_hi = Q6_Vqf32_vmpy_Vqf32Vqf32(l_tmp_hi_qf32[i], tls_val_qf32); + l_tmp_lo = Q6_Vw_equals_Vsf(Q6_Vsf_equals_Vqf32(l_tmp_lo)); + l_tmp_hi = Q6_Vw_equals_Vsf(Q6_Vsf_equals_Vqf32(l_tmp_hi)); + l_tmp[i] = Q6_Vh_vsat_VwVw(l_tmp_hi, l_tmp_lo); + } + + // Shuffle and store: + // only need to shuffle to 32-bit, + // as even and odd LUTs are interleaved + HVX_VectorPair l_pa[lut_size / 2]; + HVX_VectorPair l_pb[lut_size / 2]; + + // 32-bit, interval=1 +#pragma unroll + for (int32_t i = 0; i < lut_size; i += 2) + { + l_pa[i / 2] = Q6_W_vshuff_VVR(l_tmp[i + 1], l_tmp[i], -4); + } + // 64-bit, interval=2 +#pragma unroll + for (int32_t i = 0; i < lut_size / 2; i += 2) + { + l_pb[i + 0] = 
Q6_W_vshuff_VVR(Q6_V_lo_W(l_pa[i + 1]), Q6_V_lo_W(l_pa[i + 0]), -8); + l_pb[i + 1] = Q6_W_vshuff_VVR(Q6_V_hi_W(l_pa[i + 1]), Q6_V_hi_W(l_pa[i + 0]), -8); + } + // 128-bit, interval=4 +#pragma unroll + for (int32_t i = 0; i < lut_size / 2; i += 4) + { + l_pa[i + 0] = Q6_W_vshuff_VVR(Q6_V_lo_W(l_pb[i + 2]), Q6_V_lo_W(l_pb[i + 0]), -16); + l_pa[i + 1] = Q6_W_vshuff_VVR(Q6_V_hi_W(l_pb[i + 2]), Q6_V_hi_W(l_pb[i + 0]), -16); + l_pa[i + 2] = Q6_W_vshuff_VVR(Q6_V_lo_W(l_pb[i + 3]), Q6_V_lo_W(l_pb[i + 1]), -16); + l_pa[i + 3] = Q6_W_vshuff_VVR(Q6_V_hi_W(l_pb[i + 3]), Q6_V_hi_W(l_pb[i + 1]), -16); + } + // 256-bit, interval=8 +#pragma unroll + for (int32_t i = 0; i < lut_size / 2; i += 8) + { + l_pb[i + 0] = Q6_W_vshuff_VVR(Q6_V_lo_W(l_pa[i + 4]), Q6_V_lo_W(l_pa[i + 0]), -32); + l_pb[i + 1] = Q6_W_vshuff_VVR(Q6_V_hi_W(l_pa[i + 4]), Q6_V_hi_W(l_pa[i + 0]), -32); + l_pb[i + 2] = Q6_W_vshuff_VVR(Q6_V_lo_W(l_pa[i + 5]), Q6_V_lo_W(l_pa[i + 1]), -32); + l_pb[i + 3] = Q6_W_vshuff_VVR(Q6_V_hi_W(l_pa[i + 5]), Q6_V_hi_W(l_pa[i + 1]), -32); + l_pb[i + 4] = Q6_W_vshuff_VVR(Q6_V_lo_W(l_pa[i + 6]), Q6_V_lo_W(l_pa[i + 2]), -32); + l_pb[i + 5] = Q6_W_vshuff_VVR(Q6_V_hi_W(l_pa[i + 6]), Q6_V_hi_W(l_pa[i + 2]), -32); + l_pb[i + 6] = Q6_W_vshuff_VVR(Q6_V_lo_W(l_pa[i + 7]), Q6_V_lo_W(l_pa[i + 3]), -32); + l_pb[i + 7] = Q6_W_vshuff_VVR(Q6_V_hi_W(l_pa[i + 7]), Q6_V_hi_W(l_pa[i + 3]), -32); + } + // write back + LType *l_base = l + (group_q + q) * lut_size; +#pragma unroll + for (int32_t i = 0; i < lut_size / 2; i += 1) + { + vmem(l_base + (i * 2) * VLEN / sizeof(LType)) = Q6_V_lo_W(l_pb[i]); + vmem(l_base + (i * 2 + 1) * VLEN / sizeof(LType)) = Q6_V_hi_W(l_pb[i]); + } + if ((q_group_size >= VecQ) && ((group_q + q) % q_group_size == (q_group_size - VecQ))) + { + // self_sum + for (int32_t i = VLEN / 2; i >= 4; i >>= 1) + { + lb_val_vec = Q6_Vqf32_vadd_Vqf32Vqf32(lb_val_vec, Q6_V_vlalign_VVR(lb_val_vec, zero_vec, i)); + } + vmem(tmp_buf) = Q6_Vsf_equals_Vqf32(lb_val_vec); + lb_p[(group_q + q) / 
q_group_size] = -((const float *)tmp_buf)[VLEN / 4 - 1] * 0.5f; + lb_val_vec = zero_vec; + } + if (q_group_size < VecQ) + { + // self_sum with VecQ/q_group_size groups + const int32_t sum_len = VLEN / (VecQ / q_group_size); + for (int32_t i = sum_len / 2; i >= 4; i >>= 1) + { + lb_val_vec = Q6_Vqf32_vadd_Vqf32Vqf32(lb_val_vec, Q6_V_vlalign_VVR(lb_val_vec, zero_vec, i)); + } + vmem(tmp_buf) = Q6_Vsf_equals_Vqf32(lb_val_vec); + for (int32_t i = VLEN / 4 - 1; i >= 0; i -= sum_len / 4) + { + lb_p[(group_q + q) / q_group_size + i / (sum_len / 4)] = -((const float *)tmp_buf)[i] * 0.5f; + } + lb_val_vec = zero_vec; + } + } // q_act_group_size + } + + return 0; +} + +// For fine-grained group-wise quantization (GPTQ) +template +inline typename std::enable_if_t::value && std::is_same::value && std::is_same::value && (GroupSize > 0), int> +hvx_tbl(int32_t GemmM, int32_t GemmK, int32_t GemmN, const LType *l, const float *ls, const float *lb, const uint8_t *w, const XType *s, CType *c) +{ + UNUSED(GemmN); + + // Number of elements in a single 4bit pack + constexpr int8_t mask_4bit = 0b1111; + constexpr int8_t shift_len = 4; + + const HVX_Vector mask_vec = Q6_Vb_vsplat_R(mask_4bit); + const HVX_Vector ones_vec = Q6_Vh_vsplat_R(0x3C00); // 1.0f + + constexpr int32_t lut_size = 16; + constexpr int32_t lut_bytes = lut_size * sizeof(LType); + // K, M -> Q, P lookup Q tables with P indices + // Q = K / g, P = M * Bits + // x_shape: (Q / TileQ, TileQ / VecQ, VecQ, lut_size) = (Q, lut_size), elem_size = 2 bytes + // w_shape: (P / TileP, Q / TileQ, TileP / VecP, TileQ / VecQ, VecQ, VecP) indices, elem_size = g / 8 = 0.5 bytes + // indices of two VecQ are zipped into one Vector + const int32_t Q = GemmK / g; + const int32_t P = GemmM * Bits; + + constexpr int32_t q_group_size = GroupSize / g; + constexpr int32_t q_act_group_size = ActGroupSize / g; + // compute block size + constexpr int32_t cmp_blk_size = MIN(GroupSize / g, ActGroupSize / g); + + constexpr int32_t VecQ = VLEN / 
lut_bytes; + constexpr int32_t VecP = VLEN / sizeof(uint8_t); + + constexpr int32_t TileQ = TileK / g; + // TileP = ThreadP + const int32_t TileP = P; + + // In practice, for int16_t activation, group size < act group size (not required) + static_assert((ActGroupSize % GroupSize == 0) || (GroupSize % ActGroupSize) == 0, "ActGroupSize or GroupSize must be divisible by the other"); + // Implies that GroupSize % 16 == 0 + static_assert((cmp_blk_size % VecQ == 0), "cmp_blk_size must be divisible by VecQ"); + static_assert((TileQ % cmp_blk_size == 0), "TileQ must be divisible by cmp_blk_size"); // this requirement is unnecessary. however, i enforce it to simplify the code + static_assert((Bits <= 4 && Bits >= 2), "2 <= Bits <= 4 is required"); // Bits == 1 also works. Just need to multiply lb by 2 + + // Step.1: TABLE TOOKUP + HVX_Vector lvec_arr[TileQ / VecQ]; + + memset(c, 0, sizeof(CType) * TileP); + + for (int32_t tile_q = 0; tile_q < Q; tile_q += TileQ) + { +#pragma unroll + for (int32_t vec_q = 0; vec_q < TileQ; vec_q += VecQ) + { + lvec_arr[vec_q / VecQ] = vmem(l + (tile_q + vec_q) * lut_size); + } + + // we can't prefetch scales here, as the size is too large + // e.g., 64KB for m=4096, group_size=64, TileQ=64, float16, ZeroPoint=true + // prefetch size = s_l2fetch_p / Bits * sizeof(XType) * TileQ / q_group_size * (1 + ZeroPoint) + // e.g., 2KB for the same parameters above + constexpr int32_t s_l2fetch_p = VecP * Bits; + constexpr int32_t s_l2fetch_size = (s_l2fetch_p / Bits) * (TileQ / q_group_size) * (1 + ZeroPoint); + constexpr int32_t s_l2fetch_one = (s_l2fetch_p / Bits) * (1 + ZeroPoint); + + const uint8_t *w_tile_base = w + tile_q * TileP * g / 8; + const XType *s_tile_base = s + (tile_q / q_group_size) * (TileP / Bits) * (1 + ZeroPoint); + + if (!WeightsInVTCM) { + l2fetch(s_tile_base + s_l2fetch_one, VLEN, VLEN, (s_l2fetch_size - s_l2fetch_one) * sizeof(XType) / VLEN, 0); + } + + if (tile_q + TileQ < VecQ) + { + constexpr int32_t l1cache_line = 64; + 
for (int i = (tile_q + TileQ) / q_act_group_size; i < (tile_q + TileQ * 2) / q_act_group_size; i += l1cache_line / sizeof(float)) + { + Q6_dcfetch_A((void *)(ls + i)); + } + for (int i = (tile_q + TileQ) / q_group_size; i < (tile_q + TileQ * 2) / q_group_size; i += l1cache_line / sizeof(float)) + { + Q6_dcfetch_A((void *)(lb + i)); + } + } + +#pragma unroll(Bits) + for (int32_t vec_p = 0; vec_p < TileP; vec_p += VecP) + { + // qf32 + // we should guarantee all these belong to the same bits during preprocessing + // i.e., VecBits = VecP = VecC * 4 + HVX_Vector c_vec_0 = vmem(c + (vec_p + 0)); + HVX_Vector c_vec_1 = vmem(c + (vec_p + 32)); + HVX_Vector c_vec_2 = vmem(c + (vec_p + 64)); + HVX_Vector c_vec_3 = vmem(c + (vec_p + 96)); + + // int32_t + HVX_VectorPair c_vec_lo; + HVX_VectorPair c_vec_hi; + + const uint8_t *w_base = w_tile_base + vec_p * TileQ * g / 8; + const XType *s_base = s_tile_base + vec_p / (VecP * Bits) * (TileQ / q_group_size) * VecP * (1 + ZeroPoint); + if (!WeightsInVTCM) + { + if (vec_p + VecP < TileP) + { + l2fetch(w_base + VecP * TileQ * g / 8, VecP, VecP, TileQ * g / 8, 0); + if (vec_p % s_l2fetch_p == 0) + { + l2fetch(s_base + s_l2fetch_size, VLEN, VLEN, s_l2fetch_size * sizeof(XType) / VLEN, 0); + } + } + } + +#pragma unroll + for (int32_t vec_q = 0; vec_q < TileQ; vec_q += VecQ) + { + HVX_Vector w_vec_lo = vmem(w_base + vec_q * VecP * g / 8 + 0); + HVX_Vector w_vec_hi = vmem(w_base + vec_q * VecP * g / 8 + VLEN); + + HVX_Vector w_vec_lo_bo = Q6_V_vand_VV(w_vec_lo, mask_vec); // Q = 0 + HVX_Vector w_vec_hi_bo = Q6_V_vand_VV(w_vec_hi, mask_vec); // Q = 2 + HVX_Vector w_vec_lo_to = Q6_Vh_vasr_VhR(w_vec_lo, shift_len); // Q = 1 + HVX_Vector w_vec_hi_to = Q6_Vh_vasr_VhR(w_vec_hi, shift_len); // Q = 3 + + // int16_t + // c_vec_lo_bo_lo: even bytes of w_vec_lo_bo, c_vec_lo_bo_hi: odd bytes of w_vec_lo_bo + HVX_VectorPair c_vec_lo_bo = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_lo_bo, lvec_arr[vec_q / VecQ], 0); // Q = 0, even lo + HVX_VectorPair 
c_vec_hi_bo = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_hi_bo, lvec_arr[vec_q / VecQ], 1); // Q = 2, even hi + HVX_VectorPair c_vec_lo_to = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_lo_to, lvec_arr[vec_q / VecQ], 2); // Q = 1, odd lo + HVX_VectorPair c_vec_hi_to = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_hi_to, lvec_arr[vec_q / VecQ], 3); // Q = 3, odd hi + + // After unroll, the boolean variables should be broadcasted to constexpr and the branches will be expanded + const bool cmp_blk_head = (vec_q % cmp_blk_size == 0); + const bool cmp_blk_tail = (vec_q % cmp_blk_size == (cmp_blk_size - VecQ)); + const bool q_group_tail = (vec_q % q_group_size == (q_group_size - VecQ)); + + // int32_t + // c_vec_lo: even bytes of w_vec + // c_vec_hi: odd bytes of w_vec + // TAG0: Here widening add will perform a 2x64 transpose + if (cmp_blk_head) + { + // reset int32_t sum + c_vec_lo = Q6_Ww_vadd_VhVh(Q6_V_lo_W(c_vec_lo_bo), Q6_V_lo_W(c_vec_hi_bo)); + c_vec_hi = Q6_Ww_vadd_VhVh(Q6_V_hi_W(c_vec_lo_bo), Q6_V_hi_W(c_vec_hi_bo)); + } + else + { + c_vec_lo = Q6_Ww_vaddacc_WwVhVh(c_vec_lo, Q6_V_lo_W(c_vec_lo_bo), Q6_V_lo_W(c_vec_hi_bo)); + c_vec_hi = Q6_Ww_vaddacc_WwVhVh(c_vec_hi, Q6_V_hi_W(c_vec_lo_bo), Q6_V_hi_W(c_vec_hi_bo)); + } + c_vec_lo = Q6_Ww_vaddacc_WwVhVh(c_vec_lo, Q6_V_lo_W(c_vec_lo_to), Q6_V_lo_W(c_vec_hi_to)); + c_vec_hi = Q6_Ww_vaddacc_WwVhVh(c_vec_hi, Q6_V_hi_W(c_vec_lo_to), Q6_V_hi_W(c_vec_hi_to)); + + // qf32 + if (cmp_blk_tail) + { + const XType *s_ptr = s_base + (vec_q / q_group_size) * VecP * (1 + ZeroPoint); + // for fp16 scales, 64 elements per vector + HVX_Vector s_vec_lo_fp16 = vmem(s_ptr); + HVX_Vector s_vec_hi_fp16 = vmem(s_ptr + VLEN / sizeof(XType)); + + HVX_VectorPair s_vec_lo = Q6_Wqf32_vmpy_VhfVhf(s_vec_lo_fp16, ones_vec); + HVX_VectorPair s_vec_hi = Q6_Wqf32_vmpy_VhfVhf(s_vec_hi_fp16, ones_vec); + + HVX_Vector ls_vec = Q6_V_vsplat_R(_fp32_to_bits(ls[(tile_q + vec_q) / q_act_group_size])); + HVX_Vector lb_vec; + if (ZeroPoint) { + lb_vec = 
Q6_Vh_vsplat_R(_fp16_to_bits(reinterpret_cast(lb) + (tile_q + vec_q) / q_group_size)); + } else { + lb_vec = Q6_V_vsplat_R(_fp32_to_bits(lb[(tile_q + vec_q) / q_group_size])); + } + + // int32_t -> fp32 + // TODO: consider reordering for understanding: 0, 1, 2, 3 -> lo_W(lo), hi_W(lo), lo_W(hi), hi_W(hi) + HVX_Vector c_vec_0_sf = Q6_Vsf_equals_Vw(Q6_V_lo_W(c_vec_lo)); + HVX_Vector c_vec_1_sf = Q6_Vsf_equals_Vw(Q6_V_lo_W(c_vec_hi)); + HVX_Vector c_vec_2_sf = Q6_Vsf_equals_Vw(Q6_V_hi_W(c_vec_lo)); + HVX_Vector c_vec_3_sf = Q6_Vsf_equals_Vw(Q6_V_hi_W(c_vec_hi)); + + // * ls + HVX_Vector c_vec_0_qf32 = Q6_Vqf32_vmpy_VsfVsf(c_vec_0_sf, ls_vec); + HVX_Vector c_vec_1_qf32 = Q6_Vqf32_vmpy_VsfVsf(c_vec_1_sf, ls_vec); + HVX_Vector c_vec_2_qf32 = Q6_Vqf32_vmpy_VsfVsf(c_vec_2_sf, ls_vec); + HVX_Vector c_vec_3_qf32 = Q6_Vqf32_vmpy_VsfVsf(c_vec_3_sf, ls_vec); + + // + lb (lb = -1/2 partial sum) + // only add to b=1, and once for each weights quantization group + if (q_group_tail && (vec_p % (VecP * Bits) == VecP)) + { + // (c * ls + lb) * s + z * s * lb * 2 + // = (c * ls + lb + z * lb * 2) * s + // = (c * ls + (z * 2 + 1) * lb) * s + if (ZeroPoint) + { + HVX_Vector z_vec_lo_fp16 = vmem(s_ptr + VecP); + HVX_Vector z_vec_hi_fp16 = vmem(s_ptr + VecP + VLEN / sizeof(XType)); + + HVX_VectorPair zlb_vec_lo = Q6_Wqf32_vmpy_VhfVhf(z_vec_lo_fp16, lb_vec); + HVX_VectorPair zlb_vec_hi = Q6_Wqf32_vmpy_VhfVhf(z_vec_hi_fp16, lb_vec); + + c_vec_0_qf32 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_0_qf32, Q6_V_lo_W(zlb_vec_lo)); + c_vec_1_qf32 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_1_qf32, Q6_V_lo_W(zlb_vec_hi)); + c_vec_2_qf32 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_2_qf32, Q6_V_hi_W(zlb_vec_lo)); + c_vec_3_qf32 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_3_qf32, Q6_V_hi_W(zlb_vec_hi)); + } + else + { + c_vec_0_qf32 = Q6_Vqf32_vadd_Vqf32Vsf(c_vec_0_qf32, lb_vec); + c_vec_1_qf32 = Q6_Vqf32_vadd_Vqf32Vsf(c_vec_1_qf32, lb_vec); + c_vec_2_qf32 = Q6_Vqf32_vadd_Vqf32Vsf(c_vec_2_qf32, lb_vec); + c_vec_3_qf32 = 
Q6_Vqf32_vadd_Vqf32Vsf(c_vec_3_qf32, lb_vec); + } + } + + // * s + c_vec_0_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(c_vec_0_qf32, Q6_V_lo_W(s_vec_lo)); + c_vec_1_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(c_vec_1_qf32, Q6_V_lo_W(s_vec_hi)); + c_vec_2_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(c_vec_2_qf32, Q6_V_hi_W(s_vec_lo)); + c_vec_3_qf32 = Q6_Vqf32_vmpy_Vqf32Vqf32(c_vec_3_qf32, Q6_V_hi_W(s_vec_hi)); + + c_vec_0 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_0, c_vec_0_qf32); + c_vec_1 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_1, c_vec_1_qf32); + c_vec_2 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_2, c_vec_2_qf32); + c_vec_3 = Q6_Vqf32_vadd_Vqf32Vqf32(c_vec_3, c_vec_3_qf32); + } + } + + vmem(c + (vec_p + 0)) = c_vec_0; + vmem(c + (vec_p + 32)) = c_vec_1; + vmem(c + (vec_p + 64)) = c_vec_2; + vmem(c + (vec_p + 96)) = c_vec_3; + } + } + + return 0; +} + +// For BitNet +template +inline typename std::enable_if_t::value && std::is_same::value && std::is_same::value && GroupSize == 0, int> +hvx_tbl(int32_t GemmM, int32_t GemmK, int32_t GemmN, const LType *l, const float *ls, const float *lb, const uint8_t *w, const XType *s, CType *c) +{ + UNUSED(GemmN); + + // Number of elements in a single 4bit pack + constexpr int8_t mask_4bit = 0b1111; + constexpr int8_t shift_len = 4; + + const HVX_Vector mask_vec = Q6_Vb_vsplat_R(mask_4bit); + const HVX_Vector ones_vec = Q6_Vh_vsplat_R(0x3C00); // 1.0f + + constexpr int32_t lut_size = 16; + constexpr int32_t lut_bytes = lut_size * sizeof(LType); + // K, M -> Q, P lookup Q tables with P indices + // Q = K / g, P = M * Bits + // x_shape: (Q / TileQ, TileQ / VecQ, VecQ, lut_size) = (Q, lut_size), elem_size = 2 bytes + // w_shape: (P / TileP, Q / TileQ, TileP / VecP, TileQ / VecQ, VecQ, VecP) indices, elem_size = g / 8 = 0.5 bytes + // indices of two VecQ are zipped into one Vector + const int32_t Q = GemmK / g; + const int32_t P = GemmM * Bits; + + constexpr int32_t VecQ = VLEN / lut_bytes; + constexpr int32_t VecP = VLEN / sizeof(uint8_t); + + constexpr int32_t TileQ = TileK / g; + // TileP = 
ThreadP + const int32_t TileP = P; + + // In practice, for int16_t activation, group size < act group size (not required) + static_assert((ActGroupSize == -1), "For BitNet model, only per-tensor quantization is supported"); + static_assert(!ZeroPoint, "For BitNet model, the quantization should be symmetric"); + // Implies that GroupSize % 16 == 0 + static_assert((Bits <= 4 && Bits >= 2), "2 <= Bits <= 4 is required"); // Bits == 1 also works. Just need to multiply lb by 2 + + // Step.1: TABLE TOOKUP + HVX_Vector lvec_arr[TileQ / VecQ]; + + memset(c, 0, sizeof(CType) * TileP); + + for (int32_t tile_q = 0; tile_q < Q; tile_q += TileQ) + { +#pragma unroll + for (int32_t vec_q = 0; vec_q < TileQ; vec_q += VecQ) + { + lvec_arr[vec_q / VecQ] = vmem(l + (tile_q + vec_q) * lut_size); + } + + const uint8_t *w_tile_base = w + tile_q * TileP * g / 8; + +#pragma unroll(Bits) + for (int32_t vec_p = 0; vec_p < TileP; vec_p += VecP) + { + // qf32 + // we should guarantee all these belong to the same bits during preprocessing + // i.e., VecBits = VecP = VecC * 4 + HVX_Vector c_vec_0 = vmem(c + (vec_p + 0)); + HVX_Vector c_vec_1 = vmem(c + (vec_p + 32)); + HVX_Vector c_vec_2 = vmem(c + (vec_p + 64)); + HVX_Vector c_vec_3 = vmem(c + (vec_p + 96)); + + // int32_t + HVX_VectorPair c_vec_lo = Q6_W_vcombine_VV(c_vec_2, c_vec_0); + HVX_VectorPair c_vec_hi = Q6_W_vcombine_VV(c_vec_3, c_vec_1); + + const uint8_t *w_base = w_tile_base + vec_p * TileQ * g / 8; + if (!WeightsInVTCM) + { + if (vec_p + VecP < TileP) + { + l2fetch(w_base + VecP * TileQ * g / 8, VecP, VecP, TileQ * g / 8, 0); + } + } + +#pragma unroll + for (int32_t vec_q = 0; vec_q < TileQ; vec_q += VecQ) + { + HVX_Vector w_vec_lo = vmem(w_base + vec_q * VecP * g / 8 + 0); + HVX_Vector w_vec_hi = vmem(w_base + vec_q * VecP * g / 8 + VLEN); + + HVX_Vector w_vec_lo_bo = Q6_V_vand_VV(w_vec_lo, mask_vec); // Q = 0 + HVX_Vector w_vec_hi_bo = Q6_V_vand_VV(w_vec_hi, mask_vec); // Q = 2 + HVX_Vector w_vec_lo_to = 
Q6_Vh_vasr_VhR(w_vec_lo, shift_len); // Q = 1 + HVX_Vector w_vec_hi_to = Q6_Vh_vasr_VhR(w_vec_hi, shift_len); // Q = 3 + + // int16_t + // c_vec_lo_bo_lo: even bytes of w_vec_lo_bo, c_vec_lo_bo_hi: odd bytes of w_vec_lo_bo + HVX_VectorPair c_vec_lo_bo = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_lo_bo, lvec_arr[vec_q / VecQ], 0); // Q = 0, even lo + HVX_VectorPair c_vec_hi_bo = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_hi_bo, lvec_arr[vec_q / VecQ], 1); // Q = 2, even hi + HVX_VectorPair c_vec_lo_to = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_lo_to, lvec_arr[vec_q / VecQ], 2); // Q = 1, odd lo + HVX_VectorPair c_vec_hi_to = Q6_Wh_vlut16_VbVhR_nomatch(w_vec_hi_to, lvec_arr[vec_q / VecQ], 3); // Q = 3, odd hi + + // int32_t + // c_vec_lo: even bytes of w_vec + // c_vec_hi: odd bytes of w_vec + // TAG0: Here widening add will perform a 2x64 transpose + c_vec_lo = Q6_Ww_vaddacc_WwVhVh(c_vec_lo, Q6_V_lo_W(c_vec_lo_bo), Q6_V_lo_W(c_vec_hi_bo)); + c_vec_hi = Q6_Ww_vaddacc_WwVhVh(c_vec_hi, Q6_V_hi_W(c_vec_lo_bo), Q6_V_hi_W(c_vec_hi_bo)); + c_vec_lo = Q6_Ww_vaddacc_WwVhVh(c_vec_lo, Q6_V_lo_W(c_vec_lo_to), Q6_V_lo_W(c_vec_hi_to)); + c_vec_hi = Q6_Ww_vaddacc_WwVhVh(c_vec_hi, Q6_V_hi_W(c_vec_lo_to), Q6_V_hi_W(c_vec_hi_to)); + } + + vmem(c + (vec_p + 0)) = Q6_V_lo_W(c_vec_lo); + vmem(c + (vec_p + 32)) = Q6_V_lo_W(c_vec_hi); + vmem(c + (vec_p + 64)) = Q6_V_hi_W(c_vec_lo); + vmem(c + (vec_p + 96)) = Q6_V_hi_W(c_vec_hi); + } + } + + HVX_Vector ls_vec = Q6_V_vsplat_R(_fp32_to_bits(ls[0])); + HVX_Vector lb_vec = Q6_V_vsplat_R(_fp32_to_bits(lb[0])); + HVX_Vector s_vec = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vsplat_R(_fp16_to_bits(s)), ones_vec)); +#pragma unroll(Bits) + for (int32_t vec_p = 0; vec_p < TileP; vec_p += VecP) + { +#pragma unroll + for (int32_t vec_c = vec_p; vec_c < vec_p + VecP; vec_c += VecP / sizeof(CType)) + { + HVX_Vector c_vec = vmem(c + vec_c); + c_vec = Q6_Vsf_equals_Vw(c_vec); + c_vec = Q6_Vqf32_vmpy_VsfVsf(c_vec, ls_vec); // * ls + if (vec_p % (VecP * Bits) == VecP) + { + c_vec = 
Q6_Vqf32_vadd_Vqf32Vsf(c_vec, lb_vec); // + lb + } + c_vec = Q6_Vqf32_vmpy_Vqf32Vqf32(c_vec, s_vec); // * s + vmem(c + vec_c) = c_vec; + } + } + + return 0; +} + +template +inline typename std::enable_if_t::value && std::is_same::value, int> +hvx_bit_serial(int32_t GemmM, int32_t GemmN, const CType *c, XType *y) +{ + UNUSED(GemmN); + + const int32_t P = GemmM * Bits; + + constexpr int32_t VecP = VLEN / sizeof(uint8_t); + // TileP = ThreadP + const int32_t TileP = P; + + static_assert((Bits <= 4 && Bits >= 2), "2 <= Bits <= 4 is required"); // Bits == 1 also works. Just need to multiply lb by 2 + + // Step.2: BIT-SERIAL SUM + const HVX_Vector f0_5_vec = Q6_V_vsplat_R(0x4000007e); // 0.5f + const HVX_Vector f2_0_vec = Q6_V_vsplat_R(0x40000080); // 2.0f + const HVX_Vector f4_0_vec = Q6_V_vsplat_R(0x40000081); // 4.0f + + for (int32_t vec_p = 0; vec_p < TileP; vec_p += VecP * Bits) + { + // VecP / VecC = 4 + HVX_Vector c_bits[Bits * 4]; +#pragma unroll + for (int32_t b = 0; b < Bits * 4; b++) + { + c_bits[b] = vmem(c + (vec_p + b * 32)); + } + +#pragma unroll + for (int32_t i = 0; i < 4; i++) + { + c_bits[i] = Q6_Vqf32_vmpy_Vqf32Vqf32(c_bits[i], f0_5_vec); + } + if (Bits >= 2) + { +#pragma unroll + for (int32_t i = 0; i < 4; i++) + { + c_bits[i] = Q6_Vqf32_vadd_Vqf32Vqf32(c_bits[i], c_bits[i + 4]); + } + } + if (Bits >= 3) + { +#pragma unroll + for (int32_t i = 0; i < 4; i++) + { + c_bits[i + 8] = Q6_Vqf32_vmpy_Vqf32Vqf32(c_bits[i + 8], f2_0_vec); + c_bits[i] = Q6_Vqf32_vadd_Vqf32Vqf32(c_bits[i], c_bits[i + 8]); + } + } + if (Bits == 4) + { +#pragma unroll + for (int32_t i = 0; i < 4; i++) + { + c_bits[i + 12] = Q6_Vqf32_vmpy_Vqf32Vqf32(c_bits[i + 12], f4_0_vec); + c_bits[i] = Q6_Vqf32_vadd_Vqf32Vqf32(c_bits[i], c_bits[i + 12]); + } + } + // TAG1: here narrowing performs a 64x2 transpose to restore TAG0 + HVX_Vector c_bitsum_lo = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(c_bits[2], c_bits[0])); + HVX_Vector c_bitsum_hi = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(c_bits[3], 
c_bits[1])); + vmem(y + vec_p / Bits + 0) = c_bitsum_lo; + vmem(y + vec_p / Bits + 64) = c_bitsum_hi; + } + + return 0; +} diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/TMANOpPackageInterface.cpp b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/TMANOpPackageInterface.cpp new file mode 100644 index 00000000000..8e21fcdb8d7 --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/TMANOpPackageInterface.cpp @@ -0,0 +1,288 @@ +//============================================================================== +// Auto Generated Code for TMANOpPackage +//============================================================================== + +#include "HTP/QnnHtpCommon.h" +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "HTP/core/simple_reg.h" +#include "HTP/core/unique_types.h" +#include "QnnOpPackage.h" +#include "QnnSdkBuildId.h" + +DEFINE_UNIQ_TY() +BEGIN_PKG_OPS_OPTS_LIST() + +/** Note that the order of declarations given here defines the order in which ops and graph optimizations are + * registered to the HTP Core. 
+ * Append the latest OpName at the bottom + */ +DECLARE_PKG_OPS_OPTS_LIST(PKG_TMANPrecompute) +DECLARE_PKG_OPS_OPTS_LIST(PKG_TMANLinear) +DECLARE_PKG_OPS_OPTS_LIST(PKG_TMANFinalize) + +END_PKG_OPS_OPTS_LIST() + +// op package info +static constexpr auto sg_packageName = THIS_PKG_NAME_STR; // package name passed in as compile flag + +static std::array sg_opNames{{"TMANLinear", + "TMANFinalize", + "TMANPrecompute"}}; + +static Qnn_ApiVersion_t sg_sdkApiVersion = QNN_HTP_API_VERSION_INIT; +static QnnOpPackage_Info_t sg_packageInfo = QNN_OP_PACKAGE_INFO_INIT; + +// global data +static QnnOpPackage_GlobalInfrastructure_t sg_globalInfra = +nullptr; // global infrastructure not in use for now +static bool sg_packageInitialized = false; + +/* + * user provided logging call back function + * currently only supported on linux x86-64 and nonrpc versions + * typedef void (*QnnLog_Callback_t)(const char* fmt, + * QnnLog_Level_t level, + * uint64_t timestamp, + * va_list args); + * usage: if(sg_logInitialized && level <= sg_maxLogLevel) + * sg_logCallback(fmt, level, timestamp, args); + * + * for cross rpc versions, skel side user provided logging call back function + * can be defined as part of op packages. 
maximal log level sg_maxLogLevel + * can be set by Qnn_ErrorHandle_t TMANOpPackageLogSetLevel(QnnLog_Level_t maxLogLevel) + */ +/* + * for alternative logging method provided by HTP core, please refer to log.h + */ +static QnnLog_Callback_t sg_logCallback = + nullptr; // user provided call back function pointer for logging +static QnnLog_Level_t sg_maxLogLevel = + (QnnLog_Level_t)0; // maximal log level used in user provided logging +static bool sg_logInitialized = + false; // tracks whether user provided logging method has been initialized + + +/* +* op initialization +* needs to be global in the package +* one initialization per package before any op definitions +* syntax: INIT_PACKAGE_OP_DEF() +*/ +INIT_PACKAGE_OP_DEF() + +/* +* optimization initialization +* needs to be global in the package +* one initialization per package before any optimization definitions +* syntax: INIT_PACKAGE_OPTIMIZATION_DEF() +*/ +INIT_PACKAGE_OPTIMIZATION_DEF() + +/* + * op parameter order initialization + * needs to be global in the package + * one initialization per package before any op parameter order definitions + * syntax: INIT_PACKAGE_PARAM_ORDER_DEF() + */ +INIT_PACKAGE_PARAM_ORDER_DEF() + +/* + * axis parameter name list + * optional + * needs to be global in the package + * one list per package + * for listing axis parameter names passed into Qnn_AddNode API + * HTP backend auto-adjusts values in axis parameters based on HTP backfilling + * note: HTP backend backfills tensor dimensions to 4 dimensions + * syntax: LIST_PACKAGE_AXIS_PARAMS(...) + * e.g. 
LIST_PACKAGE_AXIS_PARAMS("Axis", "AXIS", "axis") + */ +// LIST_PACKAGE_AXIS_PARAMS() + +/* + * per-channel quantized op name list + * optional + * needs to be global in the package + * one list per package + * for listing op names which support per-channel quantization + * per-axis quantization info of an op is embedded in axisScaleOffsetEncoding + * inside Qnn_Tensor_t types + * HTP backend only supports per-channel scale ops + * i.e. along last dimension, offset is always zero + * if an op name is marked as having per-channel scale support, and in + * QNN_AddNode, at least one input, parameter, or output has + * QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET type: + * then: + * HTP backend will pass to op implementation function the following: + * output(s), input(s), parameter(s), + * outputPerChannelScale(s), inputPerChannelScale(s), paramPerChannelScale(s) + * + * optimization rules can be used to remove extra perChannelScale tensors + * + * syntax: LIST_PACKAGE_PER_CHANNEL_QUANTIZED_OPS(...) + * e.g. LIST_PACKAGE_PER_CHANNEL_QUANTIZED_OPS(sg_op1Name, sg_op2Name) + */ + +// LIST_PACKAGE_PER_CHANNEL_QUANTIZED_OPS() + +/* +* Declare and define the special initialize function for HTP Backend to load +*/ +INIT_PKG_CORE_INIT_FUNC() + +/* op package API's */ + +Qnn_ErrorHandle_t TMANOpPackageInit(QnnOpPackage_GlobalInfrastructure_t infrastructure) { + if (sg_packageInitialized) return QNN_OP_PACKAGE_ERROR_LIBRARY_ALREADY_INITIALIZED; + + /* + * op parameter order registration + * registers all defined op parameter orders in the package + * syntax: REGISTER_PACKAGE_PARAM_ORDERS() + */ + REGISTER_PACKAGE_PARAM_ORDERS() + + /* + * op axis parameter name registration + * registers all axis parameter names in the package + * used with LIST_PACKAGE_AXIS_PARAMS(...) 
+ * syntax: REGISTER_PACKAGE_AXIS_PARAMS() + */ + REGISTER_PACKAGE_AXIS_PARAMS() + + /* + * per-channel scale op name registration + * registers all per-channel scale op names in the package + * used with LIST_PACKAGE_PER_CHANNEL_QUANTIZED_OPS(...) + * syntax: REGISTER_PACKAGE_PER_CHANNEL_QUANTIZED_OPS() + */ + REGISTER_PACKAGE_PER_CHANNEL_QUANTIZED_OPS() + + sg_globalInfra = infrastructure; + sg_packageInitialized = true; + return QNN_SUCCESS; +} + +Qnn_ErrorHandle_t TMANOpPackageGetInfo(const QnnOpPackage_Info_t** info) { + if (!sg_packageInitialized) return QNN_OP_PACKAGE_ERROR_LIBRARY_NOT_INITIALIZED; + if (!info) return QNN_OP_PACKAGE_ERROR_INVALID_INFO; + + sg_packageInfo = QNN_OP_PACKAGE_INFO_INIT; + sg_packageInfo.packageName = sg_packageName; + sg_packageInfo.operationNames = sg_opNames.data(); + sg_packageInfo.numOperations = sg_opNames.size(); + sg_packageInfo.sdkBuildId = QNN_SDK_BUILD_ID; + sg_packageInfo.sdkApiVersion = &sg_sdkApiVersion; + + *info = &sg_packageInfo; + return QNN_SUCCESS; +} + +Qnn_ErrorHandle_t TMANOpPackageLogInitialize(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) { + if (sg_logInitialized) return QNN_OP_PACKAGE_ERROR_LIBRARY_ALREADY_INITIALIZED; + if (!callback) return QNN_LOG_ERROR_INVALID_ARGUMENT; + if (maxLogLevel < QNN_LOG_LEVEL_ERROR) return QNN_LOG_ERROR_INVALID_ARGUMENT; + sg_logCallback = callback; + sg_maxLogLevel = maxLogLevel; + sg_logInitialized = true; + return QNN_SUCCESS; +} + +Qnn_ErrorHandle_t TMANOpPackageLogSetLevel(QnnLog_Level_t maxLogLevel) { + if (maxLogLevel < QNN_LOG_LEVEL_ERROR) return QNN_LOG_ERROR_INVALID_ARGUMENT; + sg_maxLogLevel = maxLogLevel; + return QNN_SUCCESS; +} + +Qnn_ErrorHandle_t TMANOpPackageLogTerminate() { + if (!sg_logInitialized) return QNN_OP_PACKAGE_ERROR_LIBRARY_NOT_INITIALIZED; + sg_logCallback = nullptr; + sg_maxLogLevel = (QnnLog_Level_t)0; + sg_logInitialized = false; + return QNN_SUCCESS; +} + +Qnn_ErrorHandle_t TMANOpPackageValidateOpConfig (Qnn_OpConfig_t opConfig){ 
+ if (std::string(sg_packageName) != opConfig.v1.packageName) { + return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; + } + + /* auto-generated validation code below + * Check if op config type matches any registered ops + * If a match is found, check number of inputs, outputs and params + */ + if (std::string(opConfig.v1.typeName) == "TMANLinear"){ + if (opConfig.v1.numOfParams != 3 || opConfig.v1.numOfInputs != 3 || opConfig.v1.numOfOutputs != 1){ + return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; + } + } + else if (std::string(opConfig.v1.typeName) == "TMANFinalize"){ + if (opConfig.v1.numOfParams != 3 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ + return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; + } + } + else if (std::string(opConfig.v1.typeName) == "TMANPrecompute"){ + if (opConfig.v1.numOfParams != 3 || opConfig.v1.numOfInputs != 1 || opConfig.v1.numOfOutputs != 1){ + return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; + } + } + else{ + return QNN_OP_PACKAGE_ERROR_VALIDATION_FAILURE; + } + + /* + * additional validation code here + * */ + + return QNN_SUCCESS; +} + +/* The following three functions in this comment are not called by HTP backend for now, + * no auto-generated implementations are created. Users should see example for full function signatures. 
+ * (version 1.3.0) Qnn_ErrorHandle_t TMANOpPackageCreateKernels (QnnOpPackage_GraphInfrastructure_t + * graphInfrastructure, QnnOpPackage_Node_t node, QnnOpPackage_Kernel_t** kernels, uint32_t* + * numKernels) + * (version 1.3.0) Qnn_ErrorHandle_t TMANOpPackageFreeKernels (QnnOpPackage_Kernel_t* kernels) + * + * (version 1.4.0) Qnn_ErrorHandle_t TMANOpPackageCreateOpImpl (QnnOpPackage_GraphInfrastructure_t + * graphInfrastructure, QnnOpPackage_Node_t node, QnnOpPackage_OpImpl_t* opImpl) + *(version 1.4.0) Qnn_ErrorHandle_t TMANOpPackageFreeOpImpl (QnnOpPackage_OpImpl_t opImpl) + */ + +Qnn_ErrorHandle_t TMANOpPackageTerminate() { +if (!sg_packageInitialized) return QNN_OP_PACKAGE_ERROR_LIBRARY_NOT_INITIALIZED; + +sg_globalInfra = nullptr; +sg_packageInitialized = false; +return QNN_SUCCESS; +} + +#ifdef __cplusplus +extern "C" { +#endif + + +/* latest version */ +Qnn_ErrorHandle_t TMANOpPackageInterfaceProvider(QnnOpPackage_Interface_t* interface) { + if (!interface) return QNN_OP_PACKAGE_ERROR_INVALID_ARGUMENT; + interface->interfaceVersion = {1, 4, 0}; + interface->v1_4.init = TMANOpPackageInit; + interface->v1_4.terminate = TMANOpPackageTerminate; + interface->v1_4.getInfo = TMANOpPackageGetInfo; + interface->v1_4.validateOpConfig = TMANOpPackageValidateOpConfig; + interface->v1_4.createOpImpl = nullptr; + interface->v1_4.freeOpImpl = nullptr; + interface->v1_4.logInitialize = TMANOpPackageLogInitialize; + interface->v1_4.logSetLevel = TMANOpPackageLogSetLevel; + interface->v1_4.logTerminate = TMANOpPackageLogTerminate; + return QNN_SUCCESS; +} + +#ifdef __cplusplus +} +#endif + + diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_extend.cpp b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_extend.cpp new file mode 100644 index 00000000000..a5df613a80c --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_extend.cpp @@ -0,0 +1,110 @@ +// for x86_64 simulation +#ifndef __hexagon__ + +#include +#include + 
+typedef uint16_t src_t; +typedef uint16_t src_rep_t; +#define SRC_REP_C UINT16_C +static const int srcSigBits = 10; +#define src_rep_t_clz __builtin_clz + +typedef float dst_t; +typedef uint32_t dst_rep_t; +#define DST_REP_C UINT32_C +static const int dstSigBits = 23; + +// End of specialization parameters. Two helper routines for conversion to and +// from the representation of floating-point data as integer values follow. + +static __inline src_rep_t srcToRep(src_t x) { + const union { src_t f; src_rep_t i; } rep = {.f = x}; + return rep.i; +} + +static __inline dst_t dstFromRep(dst_rep_t x) { + const union { dst_t f; dst_rep_t i; } rep = {.i = x}; + return rep.f; +} +// End helper routines. Conversion implementation follows. + +static __inline dst_t __extendXfYf2__(src_t a) { + // Various constants whose values follow from the type parameters. + // Any reasonable optimizer will fold and propagate all of these. + const int srcBits = sizeof(src_t)*CHAR_BIT; + const int srcExpBits = srcBits - srcSigBits - 1; + const int srcInfExp = (1 << srcExpBits) - 1; + const int srcExpBias = srcInfExp >> 1; + + const src_rep_t srcMinNormal = SRC_REP_C(1) << srcSigBits; + const src_rep_t srcInfinity = (src_rep_t)srcInfExp << srcSigBits; + const src_rep_t srcSignMask = SRC_REP_C(1) << (srcSigBits + srcExpBits); + const src_rep_t srcAbsMask = srcSignMask - 1; + const src_rep_t srcQNaN = SRC_REP_C(1) << (srcSigBits - 1); + const src_rep_t srcNaNCode = srcQNaN - 1; + + const int dstBits = sizeof(dst_t)*CHAR_BIT; + const int dstExpBits = dstBits - dstSigBits - 1; + const int dstInfExp = (1 << dstExpBits) - 1; + const int dstExpBias = dstInfExp >> 1; + + const dst_rep_t dstMinNormal = DST_REP_C(1) << dstSigBits; + + // Break a into a sign and representation of the absolute value + const src_rep_t aRep = srcToRep(a); + const src_rep_t aAbs = aRep & srcAbsMask; + const src_rep_t sign = aRep & srcSignMask; + dst_rep_t absResult; + + // If sizeof(src_rep_t) < sizeof(int), the subtraction 
result is promoted + // to (signed) int. To avoid that, explicitly cast to src_rep_t. + if ((src_rep_t)(aAbs - srcMinNormal) < srcInfinity - srcMinNormal) { + // a is a normal number. + // Extend to the destination type by shifting the significand and + // exponent into the proper position and rebiasing the exponent. + absResult = (dst_rep_t)aAbs << (dstSigBits - srcSigBits); + absResult += (dst_rep_t)(dstExpBias - srcExpBias) << dstSigBits; + } + + else if (aAbs >= srcInfinity) { + // a is NaN or infinity. + // Conjure the result by beginning with infinity, then setting the qNaN + // bit (if needed) and right-aligning the rest of the trailing NaN + // payload field. + absResult = (dst_rep_t)dstInfExp << dstSigBits; + absResult |= (dst_rep_t)(aAbs & srcQNaN) << (dstSigBits - srcSigBits); + absResult |= (dst_rep_t)(aAbs & srcNaNCode) << (dstSigBits - srcSigBits); + } + + else if (aAbs) { + // a is denormal. + // renormalize the significand and clear the leading bit, then insert + // the correct adjusted exponent in the destination type. + const int scale = src_rep_t_clz(aAbs) - src_rep_t_clz(srcMinNormal); + absResult = (dst_rep_t)aAbs << (dstSigBits - srcSigBits + scale); + absResult ^= dstMinNormal; + const int resultExponent = dstExpBias - srcExpBias - scale + 1; + absResult |= (dst_rep_t)resultExponent << dstSigBits; + } + + else { + // a is zero. + absResult = 0; + } + + // Apply the signbit to (dst_t)abs(a). + const dst_rep_t result = absResult | (dst_rep_t)sign << (dstBits - srcBits); + return dstFromRep(result); +} +// Use a forwarding definition and noinline to implement a poor man's alias, +// as there isn't a good cross-platform way of defining one. 
+__attribute__((noinline)) float __extendhfsf2(uint16_t a) { + return __extendXfYf2__(a); +} + +extern "C" float __gnu_h2f_ieee(uint16_t a) { + return __extendhfsf2(a); +} + +#endif // !__hexagon__ diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_trunc.cpp b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_trunc.cpp new file mode 100644 index 00000000000..71a192c9b95 --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/fp_trunc.cpp @@ -0,0 +1,134 @@ +// for x86_64 simulation +#ifndef __hexagon__ + +#include +#include + +typedef float src_t; +typedef uint32_t src_rep_t; +#define SRC_REP_C UINT32_C +static const int srcSigBits = 23; + +typedef uint16_t dst_t; +typedef uint16_t dst_rep_t; +#define DST_REP_C UINT16_C +static const int dstSigBits = 10; + +// End of specialization parameters. Two helper routines for conversion to and +// from the representation of floating-point data as integer values follow. + +static __inline src_rep_t srcToRep(src_t x) { + const union { src_t f; src_rep_t i; } rep = {.f = x}; + return rep.i; +} + +static __inline dst_t dstFromRep(dst_rep_t x) { + const union { dst_t f; dst_rep_t i; } rep = {.i = x}; + return rep.f; +} + +static __inline dst_t __truncXfYf2__(src_t a) { + // Various constants whose values follow from the type parameters. + // Any reasonable optimizer will fold and propagate all of these. 
+ const int srcBits = sizeof(src_t)*CHAR_BIT; + const int srcExpBits = srcBits - srcSigBits - 1; + const int srcInfExp = (1 << srcExpBits) - 1; + const int srcExpBias = srcInfExp >> 1; + + const src_rep_t srcMinNormal = SRC_REP_C(1) << srcSigBits; + const src_rep_t srcSignificandMask = srcMinNormal - 1; + const src_rep_t srcInfinity = (src_rep_t)srcInfExp << srcSigBits; + const src_rep_t srcSignMask = SRC_REP_C(1) << (srcSigBits + srcExpBits); + const src_rep_t srcAbsMask = srcSignMask - 1; + const src_rep_t roundMask = (SRC_REP_C(1) << (srcSigBits - dstSigBits)) - 1; + const src_rep_t halfway = SRC_REP_C(1) << (srcSigBits - dstSigBits - 1); + const src_rep_t srcQNaN = SRC_REP_C(1) << (srcSigBits - 1); + const src_rep_t srcNaNCode = srcQNaN - 1; + + const int dstBits = sizeof(dst_t)*CHAR_BIT; + const int dstExpBits = dstBits - dstSigBits - 1; + const int dstInfExp = (1 << dstExpBits) - 1; + const int dstExpBias = dstInfExp >> 1; + const int underflowExponent = srcExpBias + 1 - dstExpBias; + const int overflowExponent = srcExpBias + dstInfExp - dstExpBias; + const src_rep_t underflow = (src_rep_t)underflowExponent << srcSigBits; + const src_rep_t overflow = (src_rep_t)overflowExponent << srcSigBits; + + const dst_rep_t dstQNaN = DST_REP_C(1) << (dstSigBits - 1); + const dst_rep_t dstNaNCode = dstQNaN - 1; + + // Break a into a sign and representation of the absolute value + const src_rep_t aRep = srcToRep(a); + const src_rep_t aAbs = aRep & srcAbsMask; + const src_rep_t sign = aRep & srcSignMask; + dst_rep_t absResult; + + if (aAbs - underflow < aAbs - overflow) { + // The exponent of a is within the range of normal numbers in the + // destination format. We can convert by simply right-shifting with + // rounding and adjusting the exponent. 
+ absResult = aAbs >> (srcSigBits - dstSigBits); + absResult -= (dst_rep_t)(srcExpBias - dstExpBias) << dstSigBits; + + const src_rep_t roundBits = aAbs & roundMask; + // Round to nearest + if (roundBits > halfway) + absResult++; + // Ties to even + else if (roundBits == halfway) + absResult += absResult & 1; + } + else if (aAbs > srcInfinity) { + // a is NaN. + // Conjure the result by beginning with infinity, setting the qNaN + // bit and inserting the (truncated) trailing NaN field. + absResult = (dst_rep_t)dstInfExp << dstSigBits; + absResult |= dstQNaN; + absResult |= ((aAbs & srcNaNCode) >> (srcSigBits - dstSigBits)) & dstNaNCode; + } + else if (aAbs >= overflow) { + // a overflows to infinity. + absResult = (dst_rep_t)dstInfExp << dstSigBits; + } + else { + // a underflows on conversion to the destination type or is an exact + // zero. The result may be a denormal or zero. Extract the exponent + // to get the shift amount for the denormalization. + const int aExp = aAbs >> srcSigBits; + const int shift = srcExpBias - dstExpBias - aExp + 1; + + const src_rep_t significand = (aRep & srcSignificandMask) | srcMinNormal; + + // Right shift by the denormalization amount with sticky. + if (shift > srcSigBits) { + absResult = 0; + } else { + const bool sticky = significand << (srcBits - shift); + src_rep_t denormalizedSignificand = significand >> shift | sticky; + absResult = denormalizedSignificand >> (srcSigBits - dstSigBits); + const src_rep_t roundBits = denormalizedSignificand & roundMask; + // Round to nearest + if (roundBits > halfway) + absResult++; + // Ties to even + else if (roundBits == halfway) + absResult += absResult & 1; + } + } + + // Apply the signbit to (dst_t)abs(a). + const dst_rep_t result = absResult | sign >> (srcBits - dstBits); + return dstFromRep(result); +} + +// Use a forwarding definition and noinline to implement a poor man's alias, +// as there isn't a good cross-platform way of defining one. 
+__attribute__((noinline)) uint16_t __truncsfhf2(float a) { + return __truncXfYf2__(a); +} + +extern "C" uint16_t __gnu_f2h_ieee(float a) { + return __truncsfhf2(a); +} + +#endif // !__hexagon__ diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANFinalize.cpp b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANFinalize.cpp new file mode 100644 index 00000000000..4eb08af21e5 --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANFinalize.cpp @@ -0,0 +1,102 @@ +//============================================================================== +// Auto Generated Code for TMANOpPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "QnnOpPackage.h" +#include "HTP/core/simple_reg.h" + +#include "hvx_funcs.h" + +BEGIN_PKG_OP_DEFINITION(PKG_TMANFinalize); + +static Qnn_Scalar_t sg_opDefaultGroup_SizeScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 64}; +static Qnn_Param_t sg_opDefaultGroup_Size = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultGroup_SizeScalar}; +static Qnn_Scalar_t sg_opDefaultBitsScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 2}; +static Qnn_Param_t sg_opDefaultBits = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultBitsScalar}; +static Qnn_Scalar_t sg_opDefaultSymmetricScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 0}; +static Qnn_Param_t sg_opDefaultSymmetric = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultSymmetricScalar}; + +template +GraphStatus tmanfinalizeImpl(TensorType& y, + const TensorType& c, + const Int32Tensor& t_group_size, + const Int32Tensor& t_bits, + const Int32Tensor& t_symmetric); + +static float tmanfinalizeCostFunc(const Op *op); + 
+DEF_PACKAGE_OP((tmanfinalizeImpl), "TMANFinalize") + +// Tcm("y") results in [ERROR] [Qnn ExecuTorch]: graph_prepare.cc:217:ERROR:could not create op: q::Add.tcm +// Reason: embedding (Gather) outputs are in MainMemory +// but TMANLinear outputs are in Tcm +// add(embedding, TMANLinear) thus causes a conflict +// TODO: +// - implement custom TMANOpPackage::Add +DEF_TENSOR_PROPERTIES(Op("TMANFinalize", "c", "group_size", "bits", "symmetric"), + Flat("*", "c"), + MainMemory("*", "group_size", "bits", "symmetric"), + Tcm("c")) + +DEF_PACKAGE_PARAM_ORDER("TMANFinalize", + "group_size", + false, + &sg_opDefaultGroup_Size, + "bits", + false, + &sg_opDefaultBits, + "symmetric", + false, + &sg_opDefaultSymmetric) + +template +GraphStatus tmanfinalizeImpl(TensorType& y, + const TensorType& c, + const Int32Tensor& t_group_size, + const Int32Tensor& t_bits, + const Int32Tensor& t_symmetric) +{ + using XType = __fp16; + using CType = float; + + const int32_t gemm_m = y.dims()[3]; + const int32_t gemm_n = y.dims()[2]; + + const int32_t bits = ((const int32_t*)t_bits.raw_data_const())[0]; + + const CType* c_ptr = (const CType*)c.raw_data_const(); + XType* y_ptr = (XType*)y.raw_data(); + + if (bits == 2) + { + hvx_bit_serial(gemm_m, gemm_n, c_ptr, y_ptr); + } + else if (bits == 4) + { + hvx_bit_serial(gemm_m, gemm_n, c_ptr, y_ptr); + } + else + { + return GraphStatus::ErrorDimensions; + } + + return GraphStatus::Success; +} + +__attribute__((unused)) static float tmanfinalizeCostFunc(const Op *op) +{ + float cost = 0.0; // add cost computation here + return cost; +} + +END_PKG_OP_DEFINITION(PKG_TMANFinalize); diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp new file mode 100644 index 00000000000..d33f5494f9b --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp @@ -0,0 +1,182 @@ 
+//============================================================================== +// Auto Generated Code for TMANOpPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "QnnOpPackage.h" +#include "HTP/core/simple_reg.h" + +#include "hvx_funcs.h" + +#ifndef PREPARE_DISABLED +API_EXPORT QuickShape simpledim_chunk1_4d_split_start(Replacement &rpx, Split_Context const &splitinfo, OpRef const &orig, int dim) +{ + size_t dims[4] = { 0, 0, 0, 0 }; + dims[dim] = splitinfo.start / splitinfo.size; + return QuickShape(dims[0], dims[1], dims[2], dims[3]); +} + +API_EXPORT QuickShape simpledim_chunk1_4d_split_size(Replacement &rpx, Split_Context const &splitinfo, OpRef const &orig, int dim) +{ + size_t dims[4] = { + orig.dim(rpx.graph(), 0), + orig.dim(rpx.graph(), 1), + orig.dim(rpx.graph(), 2), + orig.dim(rpx.graph(), 3) + }; + dims[dim] = 1; + return QuickShape(dims[0], dims[1], dims[2], dims[3]); +} +#endif + +BEGIN_PKG_OP_DEFINITION(PKG_TMANLinear); + +static Qnn_Scalar_t sg_opDefaultGroup_SizeScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 64}; +static Qnn_Param_t sg_opDefaultGroup_Size = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultGroup_SizeScalar}; +static Qnn_Scalar_t sg_opDefaultBitsScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 2}; +static Qnn_Param_t sg_opDefaultBits = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultBitsScalar}; +static Qnn_Scalar_t sg_opDefaultSymmetricScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 0}; +static Qnn_Param_t sg_opDefaultSymmetric = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultSymmetricScalar}; + +template +GraphStatus tmanlinearImpl(TensorType& c, + const TensorType& l, + const TensorType& 
qweight, + const TensorType& scales, + const Int32Tensor& t_group_size, + const Int32Tensor& t_bits, + const Int32Tensor& t_symmetric); + +static float tmanlinearCostFunc(const Op *op); + +DEF_PACKAGE_OP((tmanlinearImpl), "TMANLinear") + +DEF_TENSOR_PROPERTIES( + Op("TMANLinear", "l", "qweight", "scales", "group_size", "bits", "symmetric"), + Flat("*", "qweight", "scales"), + MainMemory("qweight", "scales", "group_size", "bits", "symmetric"), + Tcm("*", "l")) + +#define SIZE_OF(WEIGHT) MUL(ELEMENTSIZE_OF(WEIGHT), DIM_OF(WEIGHT, 0), DIM_OF(WEIGHT, 1), DIM_OF(WEIGHT, 2), DIM_OF(WEIGHT, 3)) + +// GPTQ +DEF_PACKAGE_OPTIMIZATION( + EARLY, + Op("TMANLinear", "l", "qweight", "scales", "group_size", "bits", "symmetric"), + AND(GT(DIM_OF("qweight", 2), 1), GT(SIZE_OF("scales"), 128)), + AUTOSPLIT(3, "I", DIV(DIM_OF("*", 3), DIM_OF("qweight", 2)), + Op( + "TMANLinear", "l", + AUTOSPLIT_SLICE("qweight", + AUTOSPLIT_SHAPEFN_APPLY(simpledim_chunk1_4d_split_start, "I", "qweight", 2), + AUTOSPLIT_SHAPEFN_APPLY(simpledim_chunk1_4d_split_size, "I", "qweight", 2)), + AUTOSPLIT_SLICE("scales", + AUTOSPLIT_SHAPEFN_APPLY(simpledim_chunk1_4d_split_start, "I", "scales", 2), + AUTOSPLIT_SHAPEFN_APPLY(simpledim_chunk1_4d_split_size, "I", "scales", 2)), + "group_size", "bits", "symmetric"))) + +// BitNet: weight scale shouldn't be split +DEF_PACKAGE_OPTIMIZATION( + EARLY + 1, + Op("TMANLinear", "l", "qweight", "scales", "group_size", "bits", "symmetric"), + AND(GT(DIM_OF("qweight", 2), 1), LE(SIZE_OF("scales"), 128)), + AUTOSPLIT(3, "I", DIV(DIM_OF("*", 3), DIM_OF("qweight", 2)), + Op( + "TMANLinear", "l", + AUTOSPLIT_SLICE("qweight", + AUTOSPLIT_SHAPEFN_APPLY(simpledim_chunk1_4d_split_start, "I", "qweight", 2), + AUTOSPLIT_SHAPEFN_APPLY(simpledim_chunk1_4d_split_size, "I", "qweight", 2)), + "scales", "group_size", "bits", "symmetric"))) + +DEF_PACKAGE_PARAM_ORDER("TMANLinear", + "group_size", + false, + &sg_opDefaultGroup_Size, + "bits", + false, + &sg_opDefaultBits, + "symmetric", + 
false, + &sg_opDefaultSymmetric) + +template +GraphStatus tmanlinearImpl(TensorType& c, + const TensorType& l, + const TensorType& qweight, + const TensorType& scales, + const Int32Tensor& t_group_size, + const Int32Tensor& t_bits, + const Int32Tensor& t_symmetric) +{ + using LType = int16_t; + using XType = __fp16; + using CType = float; + + constexpr int32_t ACT_GROUP_SIZE = 256; + constexpr int32_t LUT_G = 4; + constexpr int32_t LUT_SIZE = 16; + constexpr int32_t TILE_K = 256; + + const int32_t group_size = ((const int32_t*)t_group_size.raw_data_const())[0]; + const int32_t bits = ((const int32_t*)t_bits.raw_data_const())[0]; + const bool zero_point = ((const int32_t*)t_symmetric.raw_data_const())[0] == 0; + + const int32_t gemm_n = c.dims()[2]; + const int32_t gemm_m = c.dims()[3] / sizeof(float) / bits; + const int32_t gemm_k = qweight.dims()[2] * qweight.dims()[3] * 32 / bits / gemm_m; + + const int32_t l_size = gemm_k / LUT_G * LUT_SIZE; + const int32_t ls_size = (ACT_GROUP_SIZE == -1) ? 
1 : (gemm_k / ACT_GROUP_SIZE); + + const LType* l_ptr = (const LType*)l.raw_data_const(); + const float* ls_ptr = (const float*)(l_ptr + l_size); + const float* lb_ptr = ls_ptr + MAX(ls_size, 128 / sizeof(float)); + + const uint8_t* w_ptr = (const uint8_t*)qweight.raw_data_const(); + const XType* s_ptr = (const XType*)scales.raw_data_const(); + CType* c_ptr = (CType*)c.raw_data(); + + if (zero_point && bits == 2 && group_size == 64) // w2g64, symmetric=False + { + hvx_tbl(gemm_m, gemm_k, gemm_n, l_ptr, ls_ptr, lb_ptr, w_ptr, s_ptr, c_ptr); + } + else if (!zero_point && bits == 4 && group_size == 128) // w4g128, symmetric=True + { + hvx_tbl(gemm_m, gemm_k, gemm_n, l_ptr, ls_ptr, lb_ptr, w_ptr, s_ptr, c_ptr); + } + else if (zero_point && bits == 4 && group_size == 128) // w4g128, symmetric=False + { + hvx_tbl(gemm_m, gemm_k, gemm_n, l_ptr, ls_ptr, lb_ptr, w_ptr, s_ptr, c_ptr); + } + else if (zero_point && bits == 4 && group_size == 64) // w4g64, symmetric=False + { + hvx_tbl(gemm_m, gemm_k, gemm_n, l_ptr, ls_ptr, lb_ptr, w_ptr, s_ptr, c_ptr); + } + else if (!zero_point && bits == 2 && group_size == 0) // bitnet + { + hvx_tbl(gemm_m, gemm_k, gemm_n, l_ptr, ls_ptr, lb_ptr, w_ptr, s_ptr, c_ptr); + } + else + { + return GraphStatus::ErrorDimensions; + } + + return GraphStatus::Success; +} + +__attribute__((unused)) static float tmanlinearCostFunc(const Op *op) +{ + float cost = 0.0; // add cost computation here + return cost; +} + +END_PKG_OP_DEFINITION(PKG_TMANLinear); diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANPrecompute.cpp b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANPrecompute.cpp new file mode 100644 index 00000000000..f6114fb0fb4 --- /dev/null +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANPrecompute.cpp @@ -0,0 +1,115 @@ +//============================================================================== +// Auto Generated Code for TMANOpPackage 
+//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "QnnOpPackage.h" +#include "HTP/core/simple_reg.h" + +#include "hvx_funcs.h" + +BEGIN_PKG_OP_DEFINITION(PKG_TMANPrecompute); + +static Qnn_Scalar_t sg_opDefaultGroup_SizeScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 64}; +static Qnn_Param_t sg_opDefaultGroup_Size = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultGroup_SizeScalar}; +static Qnn_Scalar_t sg_opDefaultBitsScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 2}; +static Qnn_Param_t sg_opDefaultBits = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultBitsScalar}; +static Qnn_Scalar_t sg_opDefaultSymmetricScalar = {.dataType = Qnn_DataType_t::QNN_DATATYPE_INT_32, + .int32Value = 0}; +static Qnn_Param_t sg_opDefaultSymmetric = {.paramType = QNN_PARAMTYPE_SCALAR, + .scalarParam = sg_opDefaultSymmetricScalar}; + +template +GraphStatus tmanprecomputeImpl(TensorType& l, + const TensorType& x, + const Int32Tensor& t_group_size, + const Int32Tensor& t_bits, + const Int32Tensor& t_symmetric); + +static float tmanprecomputeCostFunc(const Op *op); + +DEF_PACKAGE_OP((tmanprecomputeImpl), "TMANPrecompute") + +DEF_TENSOR_PROPERTIES(Op("TMANPrecompute", "x", "group_size", "bits", "symmetric"), + Flat("*", "x"), + MainMemory("group_size", "bits", "symmetric"), + Tcm("*", "x")) + +DEF_PACKAGE_PARAM_ORDER("TMANPrecompute", + "group_size", + false, + &sg_opDefaultGroup_Size, + "bits", + false, + &sg_opDefaultBits, + "symmetric", + false, + &sg_opDefaultSymmetric) + +template +GraphStatus tmanprecomputeImpl(TensorType& l, + const TensorType& x, + const Int32Tensor& t_group_size, + const Int32Tensor& t_bits, + const Int32Tensor& t_symmetric) +{ + using LType = int16_t; + using XType = __fp16; + + 
constexpr int32_t ACT_GROUP_SIZE = 256; + constexpr int32_t LUT_G = 4; + constexpr int32_t LUT_SIZE = 16; + + const int32_t gemm_k = x.dims()[3]; + const int32_t gemm_n = x.dims()[2]; + + const int32_t group_size = ((const int32_t*)t_group_size.raw_data_const())[0]; + const bool zero_point = ((const int32_t*)t_symmetric.raw_data_const())[0] == 0; + + const int32_t l_size = gemm_k / LUT_G * LUT_SIZE; + const int32_t real_act_group_size = (group_size == 0) ? -1 : ACT_GROUP_SIZE; + const int32_t ls_size = (real_act_group_size == -1) ? 1 : (gemm_k / real_act_group_size); + + const XType* x_ptr = (const XType*)x.raw_data_const(); + LType* l_ptr = (LType*)l.raw_data(); + float* ls_ptr = (float*)(l_ptr + l_size); + float* lb_ptr = ls_ptr + MAX(ls_size, 128 / sizeof(float)); + + if (zero_point && group_size == 64) // w2g64, symmetric=False + { + hvx_lut_ctor(gemm_k, gemm_n, x_ptr, l_ptr, ls_ptr, lb_ptr); + } + else if (!zero_point && group_size == 128) // w4g128, symmetric=True + { + hvx_lut_ctor(gemm_k, gemm_n, x_ptr, l_ptr, ls_ptr, lb_ptr); + } + else if (zero_point && group_size == 128) // w4g128, symmetric=False + { + hvx_lut_ctor(gemm_k, gemm_n, x_ptr, l_ptr, ls_ptr, lb_ptr); + } + else if (!zero_point && group_size == 0) // bitnet + { + hvx_lut_ctor(gemm_k, gemm_n, x_ptr, l_ptr, ls_ptr, lb_ptr); + } + else + { + return GraphStatus::ErrorDimensions; + } + + return GraphStatus::Success; +} + +__attribute__((unused)) static float tmanprecomputeCostFunc(const Op *op) +{ + float cost = 0.0; // add cost computation here + return cost; +} + +END_PKG_OP_DEFINITION(PKG_TMANPrecompute); diff --git a/backends/qualcomm/tests/test_qnn_manager.py b/backends/qualcomm/tests/test_qnn_manager.py new file mode 100644 index 00000000000..dc2d0756e20 --- /dev/null +++ b/backends/qualcomm/tests/test_qnn_manager.py @@ -0,0 +1,27 @@ +from executorch.backends.qualcomm.utils.utils import ( + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + PyQnnManagerAdaptor, +) +from 
executorch.backends.qualcomm.serialization.qc_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.partition.qnn_partitioner import ( + generate_qnn_executorch_option, +) + +dummy_compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=QcomChipset.SM8650, + backend_options=generate_htp_compiler_spec(use_fp16=False), +) +qnn_mgr = PyQnnManagerAdaptor.QnnManager( + generate_qnn_executorch_option(dummy_compiler_specs) +) +qnn_mgr.Init() +qnn_mgr.Destroy() + +qnn_mgr = PyQnnManagerAdaptor.QnnManager( + generate_qnn_executorch_option(dummy_compiler_specs) +) +qnn_mgr.Init() +qnn_mgr.Destroy() diff --git a/backends/qualcomm/tests/test_tman_linear.py b/backends/qualcomm/tests/test_tman_linear.py new file mode 100644 index 00000000000..66b21047897 --- /dev/null +++ b/backends/qualcomm/tests/test_tman_linear.py @@ -0,0 +1,41 @@ +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear +import torch +import os +import numpy as np + +M, K, N = 2048, 8192, 1 +bits = 4 +group_size = 128 + +qlinear = TorchQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=K, + out_features=M, + bias=False, +) +data_path = f"m{M}_k{K}_g{group_size}/" +qlinear.load_state_dict(torch.load(os.path.join(data_path, "qlinear.pt"))) +qlinear.post_init() + +x = torch.from_numpy(np.fromfile(os.path.join(data_path, "x.bin"), dtype=np.float16)).reshape(N, K) +y_ref = qlinear.forward(x) + +from executorch.backends.qualcomm.builders.utils import unpack_gptqv2, hvx_preprocess_weights +w, scales, zeros, _, _, _ = unpack_gptqv2(qlinear.qweight.numpy(), qlinear.scales.numpy(), qlinear.qzeros.numpy()) +w.tofile(os.path.join(data_path, "w_unpacked.bin")) +scales.tofile(os.path.join(data_path, "s_unpacked.bin")) +zeros.tofile(os.path.join(data_path, "z_unpacked.bin")) + +w_dq = w.T.reshape(K // group_size, group_size, M).astype(np.float16) - (2 ** (bits - 1)) +w_dq = w_dq.transpose(1, 0, 2) * scales.T +w_dq = w_dq - zeros.T +w_dq = 
w_dq.transpose(1, 0, 2).reshape(K, M) + +y_ref2 = x.numpy().dot(w_dq) + +qweight_repacked, scales_repacked = hvx_preprocess_weights(w, scales, zeros, bits, tile_p=M*bits) +qweight_repacked.tofile(os.path.join(data_path, "w_repacked.bin")) +scales_repacked.tofile(os.path.join(data_path, "s_repacked.bin")) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 3653cd3176f..c7714f5a279 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -24,6 +24,11 @@ QNN_TENSOR_TYPE_MAP, ) from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader +from executorch.backends.qualcomm.builders.custom_ops import ( + tman_linear, + tman_bitnet_linear, +) +from executorch.backends.qualcomm.builders.utils import unpack_weights from executorch.backends.qualcomm.partition.qnn_partitioner import ( generate_qnn_executorch_option, get_skip_decomp_table, @@ -145,6 +150,218 @@ def qnn_edge_config() -> exir.EdgeCompileConfig: ) +def convert_linear_to_qlinear(module: torch.nn.Module, qlinear_cls): + from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + def replace_linear(module: torch.nn.Module): + attr_strs = dir(module) + if isinstance(module, torch.nn.ModuleList): + attr_strs += [str(i) for i in range(len(module))] + + for attr_str in attr_strs: + target_attr = getattr(module, attr_str) + if isinstance(target_attr, torch.nn.Linear): + qlinear = qlinear_cls( + in_features=target_attr.in_features, + out_features=target_attr.out_features, + bias=target_attr.bias is not None, + ) + # The model should have been converted to gptq_v2 in convert_gptq_weights_to_llama.py + qlinear.qzero_format(2) + assert isinstance(qlinear, TorchQuantLinear) + setattr(module, attr_str, qlinear) + + for _, sub_module in module.named_children(): + sub_module = replace_linear(sub_module) + return module + + return replace_linear(module) + + +def convert_qlinear_to_linear(module: torch.nn.Module): + from 
class TMANLinear(torch.nn.Module):
    """Holds the packed GPTQ tensors of a ``TorchQuantLinear`` and runs the
    forward pass through the ``tman_linear`` custom op.

    With ``n_splits > 1`` the packed tensors are allocated as zero
    placeholders sized for a single output split; the caller is expected to
    fill them in afterwards.
    """

    def __init__(self, qlinear: torch.nn.Module, n_splits: int = 1):
        super().__init__()
        # qzero format 1 -> GPTQv1 (AutoGPTQ); format 2 -> GPTQv2 (GPTQModel)
        self.gptq_v2 = qlinear.qzero_format() == 2
        self.in_features = qlinear.in_features
        self.out_features = qlinear.out_features // n_splits

        # Only the quantization metadata is kept; the unpacked tensors
        # returned by unpack_gptqv2 are discarded.
        _, _, _, bits, group_size, symmetric = unpack_gptqv2(
            qlinear.qweight.detach().numpy(),
            qlinear.scales.detach().numpy(),
            qlinear.qzeros.detach().numpy(),
            self.gptq_v2,
        )

        if n_splits == 1:
            qweight, scales, qzeros = qlinear.qweight, qlinear.scales, qlinear.qzeros
        else:
            qweight, scales, qzeros = (
                src.new_zeros((src.shape[0], src.shape[1] // n_splits))
                for src in (qlinear.qweight, qlinear.scales, qlinear.qzeros)
            )
        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
        self.scales = torch.nn.Parameter(scales, requires_grad=False)
        self.qzeros = torch.nn.Parameter(qzeros, requires_grad=False)

        self.g_idx = torch.nn.Parameter(qlinear.g_idx, requires_grad=False)
        self.wf_unsqueeze_zero = torch.nn.Parameter(
            qlinear.wf_unsqueeze_zero, requires_grad=False
        )
        self.wf_unsqueeze_neg_one = torch.nn.Parameter(
            qlinear.wf_unsqueeze_neg_one, requires_grad=False
        )

        self.group_size = group_size
        self.bits = bits
        self.symmetric = symmetric

    def forward(self, x):
        return tman_linear(
            x,
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.wf_unsqueeze_zero,
            self.wf_unsqueeze_neg_one,
            self.group_size,
            self.bits,
            self.symmetric,
            self.gptq_v2,
        )

    def extra_repr(self):
        return (
            f"{self.in_features}, {self.out_features}, "
            f"group_size={self.group_size}, bits={self.bits}, "
            f"symmetric={self.symmetric}"
        )


def convert_qlinear_to_tman_linear(module: torch.nn.Module):
    """Recursively replace every GPTQ ``TorchQuantLinear`` in *module* with a
    :class:`TMANLinear` (in place); returns *module* for convenience."""
    from gptqmodel.nn_modules.qlinear import BaseQuantLinear
    from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear

    def _swap(mod: torch.nn.Module):
        names = dir(mod)
        if isinstance(mod, torch.nn.ModuleList):
            # Enumerate numeric entries explicitly as well, mirroring the
            # index-based access pattern of ModuleList.
            names += [str(i) for i in range(len(mod))]

        for name in names:
            child = getattr(mod, name)
            if isinstance(child, BaseQuantLinear):
                if not isinstance(child, TorchQuantLinear):
                    raise RuntimeError("Only GPTQ TorchQuantLinear backend is supported")
                child.post_init()
                setattr(mod, name, TMANLinear(child))

        for _, sub in mod.named_children():
            _swap(sub)
        return mod

    return _swap(module)


# https://github.com/huggingface/transformers/pull/37742
class BitLinear(torch.nn.Module):
    """Linear layer whose ternary weights are packed four values per uint8
    byte with a single scale; dequantization happens inside the
    ``tman_bitnet_linear`` custom op."""

    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
        super().__init__()
        VALUES_PER_ITEM = 4  # packed 2-bit values per uint8 element
        self.dtype = dtype
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight",
            torch.zeros(
                (out_features // VALUES_PER_ITEM, in_features),
                dtype=torch.uint8,
                device=device,
            ),
        )
        self.register_buffer(
            "weight_scale",
            torch.ones((1), dtype=dtype, device=device),
        )
        # bias is a buffer (loaded from the checkpoint), not a Parameter.
        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features), dtype=dtype, device=device)
            )
        else:
            self.bias = None

    def forward(self, input):
        y = tman_bitnet_linear(input, self.weight, self.weight_scale)
        if self.bias is not None:
            y += self.bias.view(1, -1).expand_as(y)
        return y


def convert_linear_to_bitlinear(module: torch.nn.Module):
    """Recursively replace every ``torch.nn.Linear`` in *module* with a
    :class:`BitLinear` of matching shape (in place); weights are expected to
    be loaded afterwards from a packed checkpoint."""

    def _swap(mod: torch.nn.Module):
        names = dir(mod)
        if isinstance(mod, torch.nn.ModuleList):
            names += [str(i) for i in range(len(mod))]

        for name in names:
            child = getattr(mod, name)
            if isinstance(child, torch.nn.Linear):
                setattr(
                    mod,
                    name,
                    BitLinear(
                        child.in_features,
                        child.out_features,
                        child.bias is not None,
                    ),
                )

        for _, sub in mod.named_children():
            _swap(sub)
        return mod

    return _swap(module)


def convert_bitlinear_to_linear(module: torch.nn.Module):
    """Inverse of :func:`convert_linear_to_bitlinear`: materialize every
    :class:`BitLinear` back into a dense ``torch.nn.Linear`` (in place),
    dequantizing the packed weights via ``unpack_weights``."""

    def _swap(mod: torch.nn.Module):
        names = dir(mod)
        if isinstance(mod, torch.nn.ModuleList):
            names += [str(i) for i in range(len(mod))]

        for name in names:
            child = getattr(mod, name)
            if isinstance(child, BitLinear):
                dense = torch.nn.Linear(
                    child.in_features,
                    child.out_features,
                    bias=child.bias is not None,
                )
                dense.weight = torch.nn.Parameter(
                    unpack_weights(child.weight, dtype=child.weight_scale.dtype)
                    * child.weight_scale
                )
                dense.bias = (
                    torch.nn.Parameter(child.bias) if child.bias is not None else None
                )
                setattr(mod, name, dense)

        for _, sub in mod.named_children():
            _swap(sub)
        return mod

    return _swap(module)
+ +# TODO: reenable pyre after fixing the issues +# pyre-ignore-all-errors + +import copy +import getpass +import json +import logging +import os +import subprocess +import sys +import time +from collections import OrderedDict +from functools import partial +from multiprocessing.connection import Client + +import torch +from executorch.backends.qualcomm._passes.constant_i64_to_i32 import ConstantI64toI32 + +from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner + +from executorch.backends.qualcomm.quantizer.custom_annotation import ( + annotate_linear_16a8w_in_affine_layer, + annotate_matmul_16a8w, + annotate_prefill_kv_output, +) + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset + +from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( + flatbuffer_to_option, + option_to_flatbuffer, +) +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANTIZED_IO, +) +from executorch.backends.qualcomm.utils.utils import ( + capture_program, + convert_linear_to_conv2d, + convert_linear_to_bitlinear, + convert_bitlinear_to_linear, + generate_composite_llama_program, + generate_htp_compiler_spec, + generate_multi_graph_program, + generate_qnn_executorch_compiler_spec, + get_capture_program_passes, + get_soc_to_chipset_map, + update_spill_fill_size, +) + +from executorch.devtools.backend_debug import print_delegation_info +from executorch.examples.models.llama.source_transformation.quantize import ( + get_quant_embedding_transform, +) +from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken +from executorch.examples.qualcomm.oss_scripts.bitnet.model.static_bitnet import ( + BitNetForCausalLM, + BitNetConfig, + BitNetDecoderLayer, +) +from executorch.examples.qualcomm.utils import ( + make_output_dir, + make_quantizer, + 
setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir.backend.backend_api import to_backend +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from executorch.extension.llm.custom_ops import model_sharding +from executorch.extension.llm.export.builder import DType +from executorch.extension.llm.tokenizer.tokenizer import ( + Tokenizer as SentencePieceTokenizer, +) +from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer +from executorch.extension.llm.tokenizer.utils import get_tokenizer + +from torch.ao.quantization.observer import MinMaxObserver +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from safetensors.torch import load_file + + +sys.setrecursionlimit(4096) +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def smart_mask_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +): + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. + if pos >= ar_len: + for i, k_cache in enumerate(k_caches): + k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0] + + for i, v_cache in enumerate(v_caches): + v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :] + atten_mask[:, :, pos - ar_len] = 0 + + pos += 1 + return (atten_mask, pos, k_caches, v_caches) + + +def shift_pointer_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +): + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. 
+ if pos >= ar_len: + k_caches = [ + torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + atten_mask[:, :, -pos - 1] = 0 + + pos += 1 + return (atten_mask, pos, k_caches, v_caches) + + +def _kv_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + ar_len=1, + max_seq_len=512, + updater=smart_mask_updater, + use_i64_token=False, +): + _, atten_mask, _, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0) + + token_list = [] + # Llama2 tokenizer has no special tokens + if isinstance(tokenizer, SentencePieceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + elif isinstance(tokenizer, Tiktoken): + token_list = tokenizer.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) + elif isinstance(tokenizer, HuggingFaceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + else: + raise RuntimeError("Unkown tokenizer") + + pos = len(token_list) if len(token_list) < ar_len else ar_len + dtype = torch.int64 if use_i64_token else torch.int32 + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: + tmp_token_list = torch.tensor( + token_list[pos - ar_len : pos], dtype=dtype + ).reshape(1, -1) + tmp_pos = all_pos[:, pos - ar_len : pos] + tmp_atten_mask = atten_mask + if pos < ar_len: + tmp_token_list = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=dtype), + torch.tensor(token_list, dtype=dtype).reshape(1, -1), + ], + dim=1, + ) + tmp_pos = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=torch.int32), + all_pos[:, :pos], + ], + dim=1, + ) + tmp_atten_mask = torch.cat( + [ + torch.ones(1, ar_len, max_seq_len - pos) * 
-255.0, + atten_mask[:, :, -pos:], + ], + dim=-1, + ) + + logits, new_k_caches, new_v_caches = module( + tmp_token_list, + tmp_atten_mask, + tmp_pos, + *k_caches, + *v_caches, + ) + atten_mask, pos, k_caches, v_caches = updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches + ) + if pos > len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + + print(f"kv calibration data:\n{tokenizer.decode(token_list)}") + + +def _prefill_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + max_seq_len=512, + use_i64_token=False, +): + _, atten_mask = example_inputs + + # TODO: change criteria & support batch inputs if necessary + + token_list = [] + # Llama2 tokenizer has no special tokens + if isinstance(tokenizer, SentencePieceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + elif isinstance(tokenizer, Tiktoken): + token_list = tokenizer.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) + else: + raise RuntimeError("Unkown tokenizer") + + pos = len(token_list) + dtype = torch.int64 if use_i64_token else torch.int32 + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: + tmp_token_list = torch.tensor(token_list, dtype=dtype).reshape(1, -1) + if pos < max_seq_len: + tmp_token_list = torch.cat( + [ + tmp_token_list, + torch.zeros((1, max_seq_len - pos), dtype=dtype), + ], + dim=1, + ) + results = module( + tmp_token_list, + atten_mask, + ) + if len(results) == 3: + logits, new_k_caches, new_v_caches = results + elif len(results) == 1: + logits = results + token_list.append(torch.argmax(logits[:, pos - 1], dim=-1).item()) + pos += 1 + + print(f"prefill calibration data:\n{tokenizer.decode(token_list)}") + + +def calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + ar_len=1, + max_seq_len=512, + kv_updater=smart_mask_updater, + use_i64_token=False, +): + if 
len(example_inputs) == 2: + _prefill_calibrate( + example_inputs, + user_prompts, + module, + tokenizer, + max_seq_len, + use_i64_token, + ) + elif len(example_inputs) == 5: + _kv_calibrate( + example_inputs, + user_prompts, + module, + tokenizer, + ar_len, + max_seq_len, + updater=kv_updater, + use_i64_token=use_i64_token, + ) + else: + raise RuntimeError("Get wrong inputs") + + +def permute(weights: torch.Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + +class SingleLlama: + def __init__(self, llama_model, pte_filename) -> None: + super().__init__() + self.llama_model = llama_model + self.quant_dtype = None + self.llama_meta = self.llama_model.get_metadata() + self.has_quant_io = False + self.pte_filename = pte_filename + if self.llama_meta["get_use_kv_cache"]: + tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs( + use_kv_cache=True + ) + self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches) + else: + tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) + self.inputs = (tokens, atten_mask) + self.llama_graph_module = llama_model + + def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): + if not self.has_quant_io: + return + + # shape of k caches and v caches + kv_cache_shape = { + # single head, kv input + (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), + (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), + # single head, kv output + (self.llama_meta["get_head_dim"], self.llama_meta["get_ar_len"]), + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"]), + } + io_shape = { + # logit output + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_vocab_size"], + ), + } + + atten_mask_shape = { + ( + 
self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_max_seq_len"], + ), + } + + freq_shape = { + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"] // 2), + } + + freq_op = { + exir_ops.edge.aten.select.int, + } + + for n in gm.graph.nodes: + if n.op == "placeholder": + if ( + len(users := list(n.users)) == 1 + and users[0].meta["val"].size()[-2:] in kv_cache_shape + ): + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] + elif n.meta["val"].size() in io_shape: + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + elif n.meta["val"].size() in atten_mask_shape: + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + elif n.op == "output": + for a in n.args[0]: + if a.meta["val"].size()[-2:] in kv_cache_shape: + a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] + elif a.meta["val"].size() in io_shape: + a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + quant_attrs = a.meta["quant_attrs"] + + # Tag sharding io + if exir_ops.edge.llama.fallback.default in [ + u.target for u in list(n.users.keys()) + ] + [n.target]: + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + + # Tag select op as quantized tensors for freq_sin and freq_cos. 
It is caused by sharding + if n.target in freq_op and n.meta["val"].size() in freq_shape: + n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] + + return quant_attrs + + def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): + self.quant_dtype = quant_dtype + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) + quantizer.add_custom_quant_annotations(custom_annotations) + + self.has_quant_io = True + fx_graph_module = None + + with torch.no_grad(): + fx_graph_module = torch.export.export( + self.llama_graph_module, self.inputs, strict=True + ).module() + fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) + + logging.info("Quantizing the model...") + calibrate( + self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), + args.prompt, + fx_graph_module, + tokenizer=tokenizer, + ar_len=self.llama_meta["get_ar_len"], + max_seq_len=self.llama_meta["get_max_seq_len"], + kv_updater=args.kv_updater, + use_i64_token=args.embedding_quantize is not None, + ) + + self.llama_graph_module = convert_pt2e(fx_graph_module) + + def lowering_modules( + self, + work_space, + fixed_point_type, + use_fp16=False, + soc_model=QcomChipset.SM8650, + num_sharding=1, + passes_job=OrderedDict(), + shared_buffer=False, + verbose=False, + ): + executorch_config = ExecutorchBackendConfig( + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + with torch.no_grad(): + # backend option + backend_options = generate_htp_compiler_spec( + use_fp16=use_fp16, use_multi_contexts=num_sharding > 1 + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=soc_model, + backend_options=backend_options, + shared_buffer=shared_buffer, + ) + skip_node_op_set = {"llama.fallback.default"} + partitioner = QnnPartitioner( + compiler_specs, skip_node_op_set=skip_node_op_set + ) + edge_prog = capture_program( + self.llama_graph_module, + self.inputs, + passes_job, + ) + + if num_sharding > 1: + model_sharding.split_graph( + edge_prog.exported_program, + self.llama_meta["get_n_layers"], + shares=num_sharding, + ) + + self.quant_attrs = self._tag_ios( + edge_prog.exported_program.graph_module, + fixed_point_type=fixed_point_type, + ) + edge_prog_mgr = EdgeProgramManager( + edge_programs={"forward": edge_prog.exported_program}, + constant_methods=self.llama_meta, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) + if num_sharding > 1: + update_spill_fill_size(edge_prog_mgr.exported_program()) + + if verbose: + print_delegation_info(edge_prog_mgr.exported_program().graph_module) + + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + def get_example_inputs(self, use_kv_cache=True): + return self.llama_model.get_example_inputs(use_kv_cache) + + def get_quant_attrs(self): + return self.quant_attrs + + +def compile(args, pte_filename, tokenizer): + os.makedirs(args.artifact, exist_ok=True) + start_ts = time.time() + + config = BitNetConfig.from_pretrained(args.model_dir) + + state_dict = load_file(os.path.join(args.model_dir, "model.safetensors")) + + llama_instance_list = [] + 
use_i64_token = args.embedding_quantize is not None + with torch.device("meta"): + if args.model_mode == "kv": + llama_instance_list.append( + BitNetForCausalLM( + config, + ar_len=1, + max_seq_len=args.max_seq_len, + use_i64_token=use_i64_token, + ) + ) + elif args.model_mode == "hybrid": + llama_instance_list.append( + BitNetForCausalLM( + config, + ar_len=1, + max_seq_len=args.max_seq_len, + use_i64_token=use_i64_token, + ) + ) + llama_instance_list.append( + BitNetForCausalLM( + config, + ar_len=args.prefill_ar_len, + max_seq_len=args.max_seq_len, + use_i64_token=use_i64_token, + ) + ) + else: + raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + for llama_instance in llama_instance_list: + for layer in llama_instance.model.layers: + layer: BitNetDecoderLayer + convert_linear_to_bitlinear(layer.self_attn) + convert_linear_to_bitlinear(layer.mlp) + + for llama_instance in llama_instance_list: + incompatible_keys = llama_instance.load_state_dict( + state_dict, + strict=False, + assign=True, + ) + assert len(incompatible_keys.missing_keys) <= 1 and len(incompatible_keys.unexpected_keys) == 0 + llama_instance.tie_weights() + end_load_ts = time.time() + logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}") + + for llama_instance in llama_instance_list: + for layer in llama_instance.model.layers: + if args.use_tman: + # TODO + layer.self_attn.prepare_tman() + else: + convert_bitlinear_to_linear(layer.self_attn) + layer.self_attn.prepare_sha() + convert_bitlinear_to_linear(layer.mlp) + + use_fp16 = True + fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} + if args.ptq: + use_fp16 = False + fixed_point_type["kv_type"] = torch.uint8 + if args.ptq == "8a8w": + fixed_point_type["io_type"] = torch.uint8 + elif args.ptq == "16a4w": + fixed_point_type["io_type"] = torch.uint16 + else: + assert args.ptq in [ + "8a8w", + "16a4w", + ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." 
+ quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") + + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + + passes_job = get_capture_program_passes() + if args.dtype_override is not None: + dtype_override = DType[args.dtype_override] + for i in range(len(llama_instance_list)): + llama_instance_list[i] = llama_instance_list[i].to( + dtype_override.to_torch_dtype() + ) + + for i in range(len(llama_instance_list)): + if args.embedding_quantize: + llama_instance_list[i] = get_quant_embedding_transform(args)( + llama_instance_list[i] + ) + passes_job[ConstantI64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "skip_node" + ] = {"tokens"} + llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i]) + print(llama_instance_list[i]) + llama_instance_list[i] = SingleLlama( + llama_instance_list[i].eval(), pte_filename + ) + + if args.ptq: + start_quantize_ts = time.time() + custom_annotations = (annotate_matmul_16a8w,) + if args.llama_model == "stories110m": + custom_annotations = custom_annotations + ( + annotate_linear_16a8w_in_affine_layer, + ) + if args.ptq != None: + kv_quant_attrs = {} + for i, llama_instance in enumerate(llama_instance_list): + llama_instance.quantize( + quant_dtype=quant_dtype, + args=args, + tokenizer=tokenizer, + custom_annotations=custom_annotations, + ) + # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later + if i == 0 and args.model_mode == "hybrid": + output_indices = 0 + for node in llama_instance.llama_graph_module.graph.nodes: + if node.op == "output": + for output in node.args[0]: + kv_quant_attrs[output_indices] = output.args[1:] + output_indices += 1 + break + custom_annotations = custom_annotations + ( + partial( + annotate_prefill_kv_output, + kv_quant_attrs=kv_quant_attrs, + ), + ) + end_quantize_ts = time.time() + logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") + + start_lowering_ts = time.time() + quant_attrs = None + 
+ if args.model_mode in ["kv"]: + llama_instance_list[0].lowering_modules( + args.artifact, + fixed_point_type, + use_fp16=use_fp16, + soc_model=get_soc_to_chipset_map()[args.model], + num_sharding=args.num_sharding, + passes_job=passes_job, + shared_buffer=args.shared_buffer, + ) + quant_attrs = llama_instance_list[0].get_quant_attrs() + elif args.model_mode == "hybrid": + sample_inputs_list = [ + llama_instace.inputs for llama_instace in llama_instance_list + ] + edge_progs = [ + capture_program( + llama_instance.llama_graph_module, + sample_input, + passes_job=passes_job, + ) + for llama_instance, sample_input in zip( + llama_instance_list, sample_inputs_list + ) + ] + + if args.num_sharding > 1: + for i in range(len(llama_instance_list)): + model_sharding.split_graph( + edge_progs[i].exported_program, + llama_instance_list[i].llama_meta["get_n_layers"], + shares=args.num_sharding, + ) + + for i in range(len(llama_instance_list)): + quant_attrs = llama_instance_list[i]._tag_ios( + edge_progs[i].exported_program.graph_module, + fixed_point_type, + ) + backend_options = generate_htp_compiler_spec( + use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 1 + ) + graph_names = ["kv_forward", "prefill_forward"] + compiler_specs = [ + generate_qnn_executorch_compiler_spec( + soc_model=get_soc_to_chipset_map()[args.model], + backend_options=backend_options, + shared_buffer=args.shared_buffer, + multiple_graphs=True, + weight_sharing=not args.enable_x86_64, # x86 emulator does not support weight sharing + graph_name=graph_name, + ) + for graph_name in graph_names + ] + skip_node_op_set = {"llama.fallback.default"} + exported_programs = [ + to_backend( + edge_prog.exported_program, + QnnPartitioner(compiler_specs[i], skip_node_op_set=skip_node_op_set), + ) + for i, edge_prog in enumerate(edge_progs) + ] + if args.num_sharding > 1: + max_sf_size = update_spill_fill_size(exported_programs) + qnn_executorch_options = flatbuffer_to_option(compiler_specs[0][0].value) + 
qnn_executorch_options.backend_options.htp_options.max_sf_buf_size = ( + max_sf_size + ) + compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options) + + if args.verbose: + for exported_program in exported_programs: + print_delegation_info(exported_program.graph_module) + + executorch_config = ExecutorchBackendConfig( + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + + bundle_progs_list = [] + lower_module_dict = {name: [] for name in graph_names} + call_delegate_inputs_dict = {name: [] for name in graph_names} + call_delegate_node_name_dict = {name: [] for name in graph_names} + outputs_dict = {name: [] for name in graph_names} + input_nodes_dict = {name: [] for name in graph_names} + for prog, graph_name in zip(exported_programs, graph_names): + for node in prog.graph_module.graph.nodes: + if ( + node.op == "call_function" + and "executorch_call_delegate" in node.name + ): + call_delegate_node_name_dict[graph_name].append(node.name) + call_delegate_inputs_list = [] + for arg in node.args: + if arg.op == "call_function": + if ( + arg.target + == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype + ): + call_delegate_inputs_list.append((arg.name, None)) + else: + while "getitem" not in arg.name: + arg = arg.args[0] + call_delegate_inputs_list.append( + (arg.args[0].name, arg.args[1]) + ) + elif arg.op == "placeholder": + call_delegate_inputs_list.append((arg.name, None)) + # No extra needs to do for get_attr node + call_delegate_inputs_dict[graph_name].append( + call_delegate_inputs_list + ) + elif node.op == "output": + for arg in node.args[0]: + outputs_dict[graph_name].append((arg.args[0].name, arg.args[1])) + for num in range(args.num_sharding - 1, -1, -1): + 
processed_bytes = [] + for prog, graph_name in zip(exported_programs, graph_names): + processed_bytes.append( + getattr(prog.graph_module, f"lowered_module_{num}").processed_bytes + ) + call_delegate_node = [ + list(node.users.keys())[0] + for node in prog.graph_module.graph.nodes + if node.op == "get_attr" and node.name == f"lowered_module_{num}" + ] + input_nodes_dict[graph_name] = [ + node + for node in call_delegate_node[0].args + if node.op == "placeholder" + or node.target + == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype + ] + prog_mgr, bundle_progs = generate_multi_graph_program( + compiler_specs=compiler_specs[0], + processed_bytes=processed_bytes, + input_nodes_dict=input_nodes_dict, + backend_config=executorch_config, + constant_methods=llama_instance_list[0].llama_meta, # kv method meta + ) + bundle_progs_list.append(bundle_progs) + for graph_name in graph_names: + lower_module_dict[graph_name].append( + prog_mgr.exported_program(graph_name).graph_module._modules.get( + "lowered_module_0" + ) + ) + exec_prog = generate_composite_llama_program( + llama_model=llama_instance_list[1].llama_model, + graph_names=graph_names, + sample_inputs_list=sample_inputs_list, + lower_module_dict=lower_module_dict, + call_delegate_node_name_dict=call_delegate_node_name_dict, + call_delegate_inputs_dict=call_delegate_inputs_dict, + outputs_dict=outputs_dict, + embedding_quantize=args.embedding_quantize, + backend_config=executorch_config, + constant_methods=llama_instance_list[1].llama_meta, # kv method meta + ) + with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: + exec_prog.write_to_file(file) + + end_lowering_ts = time.time() + logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}") + return quant_attrs + + +def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte=""): + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" + + if args.model_mode == "kv": + eval_mode = 0 + elif 
args.model_mode == "hybrid": + eval_mode = 1 + else: + raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + pte_path = ( + f"{pre_gen_pte}/{pte_filename}.pte" + if pre_gen_pte + else f"{args.artifact}/{pte_filename}.pte" + ) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + with open(f"{args.artifact}/outputs/outputs.txt", "r") as f: + outputs.append(f.read()) + + seq_len = args.max_seq_len + runner_args = " ".join( + [ + f'--prompt "{args.prompt}"', + f"--eval_mode {eval_mode}", + f"--temperature {args.temperature}", + f"--system_prompt '{args.system_prompt}'", + f"--logits_scale {quant_attrs['scale']}", + f"--logits_offset {quant_attrs['zero_point']}", + ] + ) + + runner_cmd = "" + if args.enable_x86_64: + # x86 emulator is intended for CI and not performance. Check only the first few tokens. + seq_len = min(seq_len, 16) + + if args.kv_updater == smart_mask_updater: + logging.warning( + "x86 only support ShiftPointer, overwrite kv_updater to ShiftPointer" + ) + + qnn_sdk = os.getenv("QNN_SDK_ROOT") + target = "x86_64-linux-clang" + runner_cmd = " ".join( + [ + f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&", + f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/qnn_llama_runner", + f"--tokenizer_path {runtime_tokenizer_path}", + f"--model_path {pte_path}", + f"--seq_len {seq_len}", + f"--output_path {args.artifact}/outputs/outputs.txt", + f"--kv_updater ShiftPointer", + runner_args, + ] + ) + subprocess.run( + runner_cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + ) + post_process() + else: + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + f"./qnn_llama_runner", + f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}", + f"--model_path {pte_filename}.pte", + f"--seq_len {seq_len}", + "--output_path outputs/outputs.txt", + f"--kv_updater {'SmartMask' if args.kv_updater == 
def _build_parser():
    """Build the CLI argument parser on top of the shared QNN example args.

    Returns the configured ``argparse.ArgumentParser``.
    """
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts and output by this example. Default ./llama_qnn",
        default="./llama_qnn",
        type=str,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.",
        type=str,
    )

    parser.add_argument(
        "--llama_model",
        choices=["stories110m", "llama3_2", "bitnet"],
        # Fix: help text previously omitted 'bitnet' even though it is an
        # accepted choice.
        help="The Llama model to export. Current available options are: [stories110m, llama3_2, bitnet]",
        required=True,
    )

    parser.add_argument(
        "--model_dir",
        help="Pass llama checkpoint.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_bin",
        help="For Llama2. Pass Llama2 tokenizer binary.",
        required=False,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_model",
        help="Pass llama tokenizer model.",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        help="User prompts for llama.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--system_prompt",
        help="For Llama3. Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None",
        default="",
        type=str,
    )

    parser.add_argument(
        "--temperature",
        help="Sampling temperature for llama.",
        default=0.8,
        type=float,
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16"],
        # Fix: help text previously listed only fp32 despite fp16 being a
        # valid choice.
        help="Override the dtype of the model (default is the checkpoint dtype). Options: fp32, fp16",
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Run the pre-generated llama in the given directory.",
        type=str,
    )

    parser.add_argument(
        "--num_sharding",
        type=int,
        default=1,
        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
    )

    parser.add_argument(
        "--model_mode",
        help="Export and inference kv mode or hybrid mode",
        default="kv",
        choices=["kv", "hybrid"],
        type=str,
    )

    parser.add_argument(
        "--max_seq_len",
        help="This refers to maximum number of tokens that the model can process & consider at once to generate predictions/responses.",
        default=512,
        type=int,
    )

    parser.add_argument(
        "--prefill_ar_len",
        help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
        default=32,
        type=int,
    )

    parser.add_argument(
        "--kv_updater",
        help="Choose how to update kv cache during runtime",
        choices=["smart_mask", "shift_pointer"],
        default="smart_mask",
        type=str,
    )

    parser.add_argument(
        "-E",
        "--embedding-quantize",
        default=None,
        type=str,
        help="Fallback to cpu embedding operator and type of embedding quantization, ',', e.g., '4,32'.",
    )

    parser.add_argument(
        "--use_tman",
        action="store_true",
        help="Use TMANLinear instead of QNNConv2d.",
    )

    parser.add_argument("-v", "--verbose", action="store_true")

    return parser


def export_llama(args) -> None:
    """Drive the end-to-end flow: validate args, compile the model to a
    .pte (unless pre-generated), persist the logits quant attributes, and
    run on-device/emulator inference."""
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte as true")

    if args.model_mode == "kv":
        pte_filename = "kv_llama_qnn"
    elif args.model_mode == "hybrid":
        assert (
            args.max_seq_len >= args.prefill_ar_len
        ), "Please ensure max_seq_len is >= prefill_ar_len"
        pte_filename = "hybrid_llama_qnn"
    else:
        raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

    tokenizer = get_tokenizer(args.tokenizer_model)
    runtime_tokenizer_path = ""
    if args.llama_model == "stories110m":
        assert isinstance(
            tokenizer, SentencePieceTokenizer
        ), "Wrong tokenizer provided for stories110m."
        assert (
            args.tokenizer_bin is not None
        ), "Please provide tokenizer_bin for stories110m."
        runtime_tokenizer_path = args.tokenizer_bin
    elif args.llama_model == "llama3_2":
        assert isinstance(
            tokenizer, Tiktoken
        ), "Wrong tokenizer provided for llama3_2."
        runtime_tokenizer_path = args.tokenizer_model
    elif args.llama_model == "bitnet":
        assert isinstance(
            tokenizer, HuggingFaceTokenizer
        ), "Wrong tokenizer provided for bitnet."
        runtime_tokenizer_path = args.tokenizer_model
    else:
        raise RuntimeError(f"Unknown llama_model: {args.llama_model}.")

    # Map the CLI choice to the updater callable; smart-mask requires a
    # shared buffer between host and device.
    if args.kv_updater == "smart_mask":
        args.shared_buffer = True
        args.kv_updater = smart_mask_updater
    elif args.kv_updater == "shift_pointer":
        args.kv_updater = shift_pointer_updater
    else:
        exit(f"Using an unknown kv update {args.kv_updater}")

    if args.pre_gen_pte:
        quant_attrs = json.load(
            open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")
        )
        inference(
            args, quant_attrs, pte_filename, runtime_tokenizer_path, args.pre_gen_pte
        )
        exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")

    if args.compile_only:
        quant_attrs = compile(args, pte_filename, tokenizer)
        if quant_attrs:
            json.dump(
                {
                    "scale": quant_attrs["scale"],
                    "zero_point": quant_attrs["zero_point"],
                },
                open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
            )
        else:
            logging.warning("Quant attributes of the logit is None.")

        if args.ip and args.port != -1:
            pte_path = f"{args.artifact}/{pte_filename}.pte"
            pte_size = os.path.getsize(pte_path)
            with Client((args.ip, args.port)) as conn:
                conn.send(
                    json.dumps(
                        {
                            "pte_size": pte_size,
                        }
                    )
                )
        exit(f"Finish compile_only and save to {args.artifact}")

    try:
        quant_attrs = compile(args, pte_filename, tokenizer)
        if quant_attrs:
            logging.info(
                f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}"
            )
            json.dump(
                {
                    "scale": quant_attrs["scale"],
                    "zero_point": quant_attrs["zero_point"],
                },
                open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
            )
        else:
            logging.warning("Quant attributes of the logit is None.")
        inference(args, quant_attrs, pte_filename, runtime_tokenizer_path)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Fix: re-raise the original exception instead of wrapping it in
            # a bare Exception, preserving the type and traceback.
            raise
= parser.parse_args() + export_llama(args) + + +# flake8: noqa: C901 +if __name__ == "__main__": + main() diff --git a/examples/qualcomm/oss_scripts/bitnet/model/__init__.py b/examples/qualcomm/oss_scripts/bitnet/model/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/qualcomm/oss_scripts/bitnet/model/configuration_bitnet.py b/examples/qualcomm/oss_scripts/bitnet/model/configuration_bitnet.py new file mode 100644 index 00000000000..adcdd731436 --- /dev/null +++ b/examples/qualcomm/oss_scripts/bitnet/model/configuration_bitnet.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2025 The BitNet Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +"""BitNet model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class BitNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`BitNetModel`]. It is used to instantiate an BitNet + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of + BitNet b1.58 2B4T [microsoft/bitnet-b1.58-2B-4T](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
class BitNetConfig(PretrainedConfig):
    r"""Configuration class for [`BitNetModel`].

    Instantiating a configuration with the defaults yields a configuration close to
    BitNet b1.58 2B4T [microsoft/bitnet-b1.58-2B-4T](https://huggingface.co/microsoft/bitnet-b1.58-2B-4T).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 128256):
            Vocabulary size, i.e. the number of distinct tokens representable by
            the `inputs_ids` passed when calling [`BitNetModel`].
        hidden_size (`int`, *optional*, defaults to 2560):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 6912):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 30):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 20):
            Number of attention heads for each attention layer in the decoder.
        num_key_value_heads (`int`, *optional*, defaults to 5):
            Number of key/value heads used for Grouped Query Attention. Equal to
            `num_attention_heads` means MHA; 1 means MQA; anything in between is
            GQA (see https://arxiv.org/pdf/2305.13245.pdf). Defaults to
            `num_attention_heads` when left unset.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Std-dev of the truncated_normal_initializer for weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions. Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 128000):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 128001):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 500000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the q/k/v/o projection layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import BitNetModel, BitNetConfig

    >>> # Initializing a BitNet style configuration
    >>> configuration = BitNetConfig()

    >>> # Initializing a model from the BitNet style configuration
    >>> model = BitNetModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "bitnet"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=2560,
        intermediate_size=6912,
        num_hidden_layers=30,
        num_attention_heads=20,
        num_key_value_heads=5,
        hidden_act="relu2",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=128000,
        eos_token_id=128001,
        tie_word_embeddings=False,
        rope_theta=500000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Backward compatibility: checkpoints predating GQA leave this unset,
        # which means plain multi-head attention.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        # Model geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings

        # Layer/initialization behavior.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["BitNetConfig"]
def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embeddings to ``x`` using the half-split layout.

    HuggingFace rotates the first/second halves of the last dimension instead
    of interleaving even/odd pairs; the interleaved form needs stride-2 slices
    in the StrideSlice op, which are unfriendly to the HTP backend.
    Ref: https://github.com/huggingface/transformers/issues/25199
    """
    half = x.shape[-1] // 2
    real, imag = x[..., :half], x[..., half:]
    # Broadcast the rotation tables over (batch, head) for 4-D
    # batch_prefill-mode inputs.
    if x.dim() == 4:
        freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(0)
        freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(0)
    rotated_real = real * freqs_cos - imag * freqs_sin
    rotated_imag = real * freqs_sin + imag * freqs_cos

    return torch.cat((rotated_real, rotated_imag), dim=-1)


class BitNetMLP(nn.Module):
    """BitNet feed-forward block: a gated MLP whose output is RMS-normalized
    (``ffn_sub_norm``) before the down projection — the difference from Llama.
    """

    def __init__(self, config: "BitNetConfig"):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
        self.ffn_sub_norm = torch.nn.RMSNorm(
            config.intermediate_size, eps=config.rms_norm_eps
        )

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        # Normalize before down_proj — diff with Llama.
        return self.down_proj(self.ffn_sub_norm(gated))
class BitNetAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    BitNet difference from Llama: the concatenated head outputs are passed
    through ``attn_sub_norm`` (RMSNorm) before the output projection.

    Three interchangeable forward implementations are provided:
      * ``forward``       — multi-head attention over the fused q/k/v Linears.
      * ``forward_sha``   — single-head-ahead: per-head 1x1 Conv2d projections
                            (installed by ``prepare_sha``).
      * ``forward_tman``  — per-head attention fed from the fused Linears via
                            split_with_sizes (installed by ``prepare_tman``),
                            intended for the TMANLinear custom op path.
    """

    def __init__(self, config: BitNetConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.dim = config.hidden_size
        # How many query heads share one kv head (GQA group size).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        # When True, only the newly produced k/v slices are returned to the
        # caller (the runtime appends them to its externally managed cache).
        self.output_new_cache_only = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_sub_norm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn_softmax = torch.nn.Softmax(dim=-1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """MHA forward.

        Returns ``(y, output_kh, output_vh)`` where the lists hold one tensor
        per kv head (new-slice-only when ``output_new_cache_only`` is set).
        ``atten_mask`` is additive (expected to be 0 / large-negative).
        """
        bsz, seq_len, _ = hidden_states.shape

        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim)

        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
        # After permute, k is (bsz, n_kv_heads, head_dim, seq_len) so that
        # q @ k directly yields attention scores.
        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)

        output_kh, output_vh, output_y = [], [], []
        kh, vh = [], []
        # kv cache mode: extend each per-head cache with the new k/v slices.
        if k_caches and v_caches:
            for i, _ in enumerate(k_caches):
                # k cache grows along the last (sequence) axis, v cache along dim 1.
                kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1))
                vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1))
            for i in range(self.n_heads):
                # GQA: several query heads map onto the same kv head.
                cache_idx = i // self.num_key_value_groups

                attn = q[:, :, i, :] @ kh[cache_idx]
                attn = attn * self.scaling + atten_mask
                attn = self.attn_softmax(attn)
                y = attn @ vh[cache_idx]

                output_y.append(y)

        # batch_prefill mode: attend only within the current chunk.
        else:
            kh = k
            vh = v
            for i in range(self.n_heads):
                cache_idx = i // self.num_key_value_groups

                attn = q[:, :, i, :] @ kh[:, cache_idx, :, :]
                attn = attn * self.scaling + atten_mask
                attn = self.attn_softmax(attn)
                y = attn @ vh[:, :, cache_idx, :]

                output_y.append(y)

        # Emit per-head cache updates: either just the newest position (-1)
        # or the whole chunk, depending on output_new_cache_only.
        for i in range(self.n_kv_heads):
            if self.output_new_cache_only:
                output_kh.append(k[:, i, :, -1])
                output_vh.append(v[:, -1, i, :])
            else:
                output_kh.append(k[:, i, :, :])
                output_vh.append(v[:, :, i, :])

        y = torch.concat(output_y, dim=-1)
        y = self.attn_sub_norm(y)  # diff with Llama
        y = self.o_proj(y)

        return y, output_kh, output_vh

    def prepare_sha(self):
        """Split the fused q/k/v Linears into per-head 1x1 Conv2d layers and
        switch ``forward`` to ``forward_sha`` (Conv2d maps better to HTP)."""
        self.wq_sha = nn.ModuleList(
            [
                nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
                for _ in range(self.n_heads)
            ]
        )
        self.wk_sha = nn.ModuleList(
            [
                nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wv_sha = nn.ModuleList(
            [
                nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
                for _ in range(self.n_kv_heads)
            ]
        )
        self.wo_sha = nn.Conv2d(self.n_heads * self.head_dim, self.dim, 1, bias=False)

        self.forward_mha = self.forward
        self.forward = self.forward_sha
        # Copy each head's slice of the fused weights; [:, :, None, None]
        # reshapes (out, in) Linear weights to (out, in, 1, 1) Conv2d weights.
        for i in range(self.n_heads):
            self.wq_sha[i].weight.data.copy_(
                self.q_proj.weight[
                    i * self.head_dim : (i + 1) * self.head_dim, :, None, None
                ]
            )
        for i in range(self.n_kv_heads):
            self.wk_sha[i].weight.data.copy_(
                self.k_proj.weight[
                    i * self.head_dim : (i + 1) * self.head_dim, :, None, None
                ]
            )
            self.wv_sha[i].weight.data.copy_(
                self.v_proj.weight[
                    i * self.head_dim : (i + 1) * self.head_dim, :, None, None
                ]
            )
        self.wo_sha.weight.data.copy_(self.o_proj.weight[:, :, None, None])

    def prepare_tman(self):
        """Switch ``forward`` to the TMANLinear-friendly implementation."""
        self.forward_mha = self.forward
        self.forward = self.forward_tman

    def forward_tman(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: Optional[List[torch.Tensor]] = None,
        v_caches: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """Per-head attention fed from the fused projections.

        Heads are separated with ``torch.split_with_sizes`` because that is
        the split variant supported by this backend path.
        """
        bsz, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states).reshape(bsz, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(hidden_states).reshape(bsz, seq_len, self.n_kv_heads, self.head_dim)
        v = self.v_proj(hidden_states).reshape(bsz, seq_len, self.n_kv_heads, self.head_dim)
        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
        # k becomes (bsz, n_kv_heads, head_dim, seq_len) for direct q @ k.
        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
        # Use split_with_sizes as only split_with_sizes is supported
        q = torch.split_with_sizes(q, [1 for _ in range(self.n_heads)], dim=2)
        k = torch.split_with_sizes(k, [1 for _ in range(self.n_kv_heads)], dim=1)
        v = torch.split_with_sizes(v, [1 for _ in range(self.n_kv_heads)], dim=2)
        q = [t.squeeze(2) for t in q]
        k = [t.squeeze(1) for t in k]
        v = [t.squeeze(1) if False else t.squeeze(2) for t in v]  # NOTE(review): keep squeeze(2); see original

        output_y = []
        kh, vh = [], []
        # kv cache mode
        if k_caches and v_caches:
            for i, _ in enumerate(k_caches):
                kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
                vh.append(torch.cat([v_caches[i], v[i]], dim=1))
        # batch_prefill mode
        else:
            kh = k
            vh = v

        for i, _ in enumerate(q):
            cache_idx = i // self.num_key_value_groups
            attn = q[i] @ kh[cache_idx]
            attn = attn * self.scaling + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

        y = torch.concat(output_y, dim=-1)
        y = self.attn_sub_norm(y)
        y = self.o_proj(y)

        if self.output_new_cache_only:
            return y, k, v

        return y, kh, vh

    def forward_sha(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: Optional[List[torch.Tensor]] = None,
        v_caches: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """Single-head-ahead forward using the per-head Conv2d weights
        installed by ``prepare_sha``."""
        bsz, seq_len, _ = hidden_states.shape
        # Conv2d expects channels-first: (bsz, dim, 1, seq_len).
        hidden_states = torch.reshape(
            hidden_states, (bsz, seq_len, 1, self.dim)
        ).transpose(1, 3)
        q = [
            wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
            for wq_sha in self.wq_sha
        ]
        k = [
            wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
            for wk_sha in self.wk_sha
        ]
        v = [
            wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
            for wv_sha in self.wv_sha
        ]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            # Transpose each k head to (bsz, head_dim, seq_len) for q @ k.
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)

        output_y = []
        kh, vh = [], []
        # kv cache mode
        if k_caches and v_caches:
            for i, _ in enumerate(k_caches):
                kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
                vh.append(torch.cat([v_caches[i], v[i]], dim=1))
        # batch_prefill mode
        else:
            kh = k
            vh = v

        for i, _ in enumerate(q):
            cache_idx = i // self.num_key_value_groups
            attn = q[i] @ kh[cache_idx]
            attn = attn * self.scaling + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

        y = torch.concat(output_y, dim=-1)
        y = self.attn_sub_norm(y)  # diff with Llama
        # Route the output projection through the 1x1 Conv2d as well.
        y = y.reshape(bsz, seq_len, 1, -1)
        y = y.transpose(1, 3)
        y = self.wo_sha(y)
        y = y.transpose(1, 3)
        y = y.reshape(bsz, seq_len, -1)

        if self.output_new_cache_only:
            return y, k, v

        return y, kh, vh


class BitNetDecoderLayer(nn.Module):
    """One decoder block: pre-norm attention and pre-norm MLP, each with a
    residual connection."""

    def __init__(self, config: BitNetConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = BitNetAttention(config=config, layer_idx=layer_idx)

        self.mlp = BitNetMLP(config)
        self.input_layernorm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """Return ``(hidden_states, new_k_slices, new_v_slices)`` for this layer."""
        h, k_cache, v_cache = self.self_attn(
            hidden_states=self.input_layernorm(x),
            freqs_cos=freqs_cos,
            freqs_sin=freqs_sin,
            atten_mask=atten_mask,
            k_caches=k_caches,
            v_caches=v_caches,
        )
        h = x + h
        output = h + self.mlp(self.post_attention_layernorm(h))
        return output, k_cache, v_cache
class BitNetModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BitNetDecoderLayer`]

    Args:
        config: BitNetConfig
    """

    def __init__(self, config: BitNetConfig, max_seq_len: int):
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [BitNetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.max_seq_len = max_seq_len
        # Scaled RoPE is only enabled when the config carries rope_scaling.
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.use_scaled_rope = True
            self.rope_scale_factor = config.rope_scaling["factor"]
        else:
            self.use_scaled_rope = False
            self.rope_scale_factor = None
        # Precompute the full rotary tables once; sliced per step in forward.
        freqs_cos, freqs_sin = precompute_freqs_cis(
            self.head_dim,
            self.max_seq_len,
            config.rope_theta,
            self.use_scaled_rope,
            self.rope_scale_factor,
        )
        # Buffers (not parameters) so they move with the module but are not
        # saved in checkpoints.
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        self.use_kv_cache = True
        self.n_layers = config.num_hidden_layers
        self.n_kv_heads = config.num_key_value_heads

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        tokens: torch.Tensor,
        atten_mask: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        *args,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """Run the decoder stack.

        ``*args`` holds the flattened kv caches: the first
        ``n_layers * n_kv_heads`` entries are k caches (grouped by layer),
        followed by the matching v caches in the same order.
        Returns the normalized hidden states plus the new k/v cache slices
        gathered from every layer.
        """

        output_k_cache = []
        output_v_cache = []
        # following tensors should be invariant across batches
        freqs_cos = (
            self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos
        )
        freqs_sin = (
            self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin
        )

        hidden_states = self.embed_tokens(tokens)
        for ind, decoder_layer in enumerate(self.layers):
            k_caches = None
            v_caches = None
            if self.use_kv_cache:
                # Index into the flattened cache list (k caches first, then v).
                offset_k = ind * self.n_kv_heads
                offset_v = self.n_layers * self.n_kv_heads + offset_k
                k_caches = args[offset_k : offset_k + self.n_kv_heads]
                v_caches = args[offset_v : offset_v + self.n_kv_heads]
            hidden_states, k, v = decoder_layer(
                hidden_states,
                freqs_cos=freqs_cos,
                freqs_sin=freqs_sin,
                atten_mask=atten_mask,
                k_caches=k_caches,
                v_caches=v_caches,
            )
            output_k_cache.extend(k)
            output_v_cache.extend(v)

        hidden_states = self.norm(hidden_states)

        return hidden_states, output_k_cache, output_v_cache


class BitNetForCausalLM(PreTrainedModel):
    """BitNet decoder with a language-model head, shaped for static export
    (fixed ``ar_len``/``max_seq_len``, externally managed kv caches)."""

    def __init__(
        self,
        config: BitNetConfig,
        ar_len: int = 1,
        max_seq_len: int = 128,
        use_i64_token: bool = False
    ):
        super().__init__(config)
        self.model = BitNetModel(config, max_seq_len)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # ar_len: number of tokens consumed (and logits produced) per call.
        self.ar_len = ar_len
        self.bos_id = config.bos_token_id
        self.eos_id = config.eos_token_id
        self.dim = config.hidden_size
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.max_batch_size = 1
        self.max_seq_len = max_seq_len
        self.n_kv_heads = config.num_key_value_heads
        self.n_layers = config.num_hidden_layers
        self.use_kv_cache = True
        self.use_i64_token = use_i64_token

    def forward(
        self,
        tokens: torch.Tensor,
        atten_mask: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        *args,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """Decode and project to vocabulary logits; cache slices pass through."""

        hidden_states, output_k_cache, output_v_cache = self.model(tokens, atten_mask, input_pos, *args)
        logits = self.lm_head(hidden_states)

        return logits, output_k_cache, output_v_cache

    def get_example_inputs(self, use_kv_cache=True):
        """Build representative export inputs: random tokens, an additive
        causal mask (0 visible / -255 masked), and zeroed kv caches."""
        dtype = torch.int64 if self.use_i64_token else torch.int32
        tokens = torch.randint(
            self.vocab_size, (self.max_batch_size, self.ar_len), dtype=dtype
        )

        # Start fully masked, then clear the lower triangle (causal window).
        atten_mask = torch.full((self.ar_len, self.ar_len), torch.tensor(-255.0))
        mask_cond = torch.arange(atten_mask.size(-1))
        atten_mask.masked_fill_(
            mask_cond < (mask_cond + 1).view(atten_mask.size(-1), 1), 0
        )
        if self.max_seq_len != self.ar_len:
            # Prepend masked columns for the (not yet filled) cache region.
            atten_mask = torch.cat(
                [
                    torch.ones(self.ar_len, self.max_seq_len - self.ar_len) * -255.0,
                    atten_mask,
                ],
                dim=-1,
            )
        atten_mask = atten_mask[None, :, :].expand(
            self.max_batch_size, self.ar_len, self.max_seq_len
        )
        if use_kv_cache:
            pos_ids = torch.zeros((self.max_batch_size, self.ar_len), dtype=torch.int32)
            k_cache, v_cache = [], []

            for _ in range(self.n_layers):
                for _ in range(self.n_kv_heads):
                    # transpose first to decrease the runtime efforts
                    k_cache.append(
                        torch.zeros(
                            self.max_batch_size,
                            self.head_dim,
                            self.max_seq_len - self.ar_len,
                        )
                    )
                    v_cache.append(
                        torch.zeros(
                            self.max_batch_size,
                            self.max_seq_len - self.ar_len,
                            self.head_dim,
                        )
                    )
            return (
                tokens,
                atten_mask,
                pos_ids,
                k_cache,
                v_cache,
            )

        return (
            tokens,
            atten_mask,
        )

    def get_metadata(self):
        """Model constants embedded into the PTE for the C++ runner."""
        # TODO: modify this when enabling LLAMA 7B
        return {
            "get_ar_len": self.ar_len,
            "get_bos_id": self.bos_id,
            "get_eos_id": self.eos_id,
            "get_dim": self.dim,
            "get_head_dim": self.head_dim,
            "get_max_batch_size": self.max_batch_size,
            "get_max_seq_len": self.max_seq_len,
            "get_n_bos": 1,
            "get_n_eos": 1,
            "get_n_kv_heads": self.n_kv_heads,
            "get_n_layers": self.n_layers,
            "get_vocab_size": self.vocab_size,
            "get_use_kv_cache": self.use_kv_cache,
        }

    def get_output_embeddings(self):
        return self.lm_head

    def get_input_embeddings(self):
        return self.model.embed_tokens
def write_model(model_path, model_size, output_base_path):
    """Convert a GPTQ-quantized HuggingFace Llama checkpoint into consolidated
    (Meta-style) shards, keeping the quantized tensors (qweight/scales/qzeros/
    g_idx) instead of dequantizing.

    Args:
        model_path: HF model name or local directory of the GPTQ checkpoint.
        model_size: key into NUM_SHARDS (e.g. "8B") selecting the shard count.
        output_base_path: directory containing params.json; shards are written
            here as consolidated.{i:02d}.pth.
    """
    dtype = torch.bfloat16

    params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    # Llama 3 is detected by its 128256-token vocabulary; otherwise Llama 2.
    llama_version = 3 if params.get("vocab_size") == 128256 else 2

    if "n_kv_heads" in params:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_local_key_value_heads = num_key_value_heads // num_shards
        key_value_dim = dims_per_head * num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim

    # instead of load state_dict directly
    # load state_dict from gptqmodel
    # to deal with GPTQ v1 -> GPTQ v2 transformation
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = GPTQModel.from_quantized(
        model_path,
        low_cpu_mem_usage=True,
        backend=BACKEND.TORCH,
    )
    # Sample chat kept for the optional smoke test below.
    messages = [
        {"role": "system", "content": "You are a helpful and harmless assistant. You should think step-by-step."},
        {"role": "user", "content": "How can I design a data structure in C++ to store the top 5 largest integer numbers?"},
    ]
    # NOTE(review): disabled smoke test of the loaded model; uncomment to verify generation.
    # input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    # outputs = model.generate(input_ids=input_tensor.to(model.device), max_new_tokens=32)
    # result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
    # print(result)

    loaded = model.model.state_dict()

    # permute for sliced rotary
    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        return (
            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
            .transpose(1, 2)
            .reshape(dim1, dim2)
        )

    state_dict = [{} for _ in range(num_shards)]

    # Copy one tensor (or one-per-shard list) into every shard's state dict.
    def insert(name: str, tensor: Union[List, torch.Tensor]):
        for i in range(num_shards):
            state_dict[i][name] = (
                tensor[i].clone() if isinstance(tensor, list) else tensor
            )

    # Split a tensor along `dim` and distribute one chunk per shard.
    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
        tensors = tensor.chunk(num_shards, dim=dim)
        for i, tensor in enumerate(tensors):
            state_dict[i][name] = tensor.clone()

    # Copy the four GPTQ tensors of one quantized Linear under a new name.
    def insert_quantized(name: str, original_name: str):
        insert(
            f"{name}.qweight",
            loaded[f"{original_name}.qweight"],
        )
        insert(
            f"{name}.scales",
            loaded[f"{original_name}.scales"],
        )
        insert(
            f"{name}.qzeros",
            loaded[f"{original_name}.qzeros"],
        )
        insert(
            f"{name}.g_idx",
            loaded[f"{original_name}.g_idx"],
        )

    concat_dim = 0 if llama_version == 3 else 1
    insert_chunk(
        "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim
    )
    insert("norm.weight", loaded["model.norm.weight"])
    insert_chunk("output.weight", loaded["lm_head.weight"], 0)

    for layer_i in tqdm(range(n_layers), desc="Converting layers"):

        # deal with hf permute in static_llama.py as it's hard to permute quantized weights
        insert_quantized(
            f"layers.{layer_i}.attention.wq",
            f"model.layers.{layer_i}.self_attn.q_proj",
        )
        insert_quantized(
            f"layers.{layer_i}.attention.wk",
            f"model.layers.{layer_i}.self_attn.k_proj",
        )
        insert_quantized(
            f"layers.{layer_i}.attention.wv",
            f"model.layers.{layer_i}.self_attn.v_proj",
        )
        insert_quantized(
            f"layers.{layer_i}.attention.wo",
            f"model.layers.{layer_i}.self_attn.o_proj",
        )
        insert_quantized(
            f"layers.{layer_i}.feed_forward.w1",
            f"model.layers.{layer_i}.mlp.gate_proj",
        )
        insert_quantized(
            f"layers.{layer_i}.feed_forward.w2",
            f"model.layers.{layer_i}.mlp.down_proj",
        )
        insert_quantized(
            f"layers.{layer_i}.feed_forward.w3",
            f"model.layers.{layer_i}.mlp.up_proj",
        )
        insert(
            f"layers.{layer_i}.attention_norm.weight",
            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
        )
        insert(
            f"layers.{layer_i}.ffn_norm.weight",
            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
        )
    # Llama 2 checkpoints carry explicit RoPE inverse frequencies.
    if llama_version != 3:
        base = 10000.0
        inv_freq = (
            1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
        ).to(dtype)
        insert("rope.freqs", inv_freq)

    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
        torch.save(
            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
        )


def main(
    model_path: str,
    model_size: str,
    output_dir: str,
):
    """Convert llama weights from huggingface format to consolidated format.
    params:
        model_path: model name or path to the model directory.
        model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
        output_dir: directory to save Llama weights, should contain params.json.
    """
    assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
    params_path = os.path.join(output_dir, "params.json")
    assert os.path.isfile(params_path), f"{params_path} does not exist"

    write_model(model_path, model_size, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model-path",
        type=str,
        required=True,
        help="Path to the model directory.",
    )
    parser.add_argument(
        "-s",
        "--model-size",
        type=str,
        required=True,
        choices=NUM_SHARDS.keys(),
        help="Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save Llama weights, should contains params.json.",
    )
    args = parser.parse_args()
    main(
        model_path=args.model_path,
        model_size=args.model_size,
        output_dir=args.output_dir,
    )
+ +import json +import os +from typing import List, Union + +import fire +import torch +from tqdm import tqdm +from transformers import LlamaForCausalLM # @manual + +NUM_SHARDS = { + "1B": 1, + "7B": 1, + "8B": 1, + "13B": 2, + "34B": 4, + "30B": 4, + "65B": 8, + "70B": 8, +} + + +def write_model(model_path, model_size, output_base_path): + dtype = torch.bfloat16 + + params = json.load(open(os.path.join(output_base_path, "params.json"), "r")) + num_shards = NUM_SHARDS[model_size] + n_layers = params["n_layers"] + n_heads = params["n_heads"] + n_heads_per_shard = n_heads // num_shards + dim = params["dim"] + dims_per_head = dim // n_heads + llama_version = 3 if params.get("vocab_size") == 128256 else 2 + + if "n_kv_heads" in params: + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + num_local_key_value_heads = num_key_value_heads // num_shards + key_value_dim = dims_per_head * num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + num_local_key_value_heads = n_heads_per_shard + key_value_dim = dim + + model = LlamaForCausalLM.from_pretrained( + model_path, + torch_dtype=dtype, + low_cpu_mem_usage=True, + ) + loaded = model.state_dict() + + # permute for sliced rotary + def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): + return ( + w.view(n_heads, 2, dim1 // n_heads // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + state_dict = [{} for _ in range(num_shards)] + + def insert(name: str, tensor: Union[List, torch.Tensor]): + for i in range(num_shards): + state_dict[i][name] = ( + tensor[i].clone() if isinstance(tensor, list) else tensor + ) + + def insert_chunk(name: str, tensor: torch.Tensor, dim: int): + tensors = tensor.chunk(num_shards, dim=dim) + for i, tensor in enumerate(tensors): + state_dict[i][name] = tensor.clone() + + concat_dim = 0 if llama_version == 3 else 1 + insert_chunk( + "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim + ) + insert("norm.weight", 
loaded["model.norm.weight"]) + insert_chunk("output.weight", loaded["lm_head.weight"], 0) + + for layer_i in tqdm(range(n_layers), desc="Converting layers"): + + ts = ( + permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"]) + .view(n_heads_per_shard * num_shards, dims_per_head, dim) + .chunk(num_shards, dim=0) + ) + insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts]) + + ts = ( + permute( + loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"], + num_key_value_heads, + key_value_dim, + dim, + ) + .view(num_local_key_value_heads * num_shards, dims_per_head, dim) + .chunk(num_shards, dim=0) + ) + insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts]) + + ts = ( + loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"] + .view(num_local_key_value_heads * num_shards, dims_per_head, dim) + .chunk(num_shards, dim=0) + ) + insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts]) + + insert_chunk( + f"layers.{layer_i}.attention.wo.weight", + loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"], + 1, + ) + + insert_chunk( + f"layers.{layer_i}.feed_forward.w1.weight", + loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"], + 0, + ) + + insert_chunk( + f"layers.{layer_i}.feed_forward.w2.weight", + loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"], + 1, + ) + + insert_chunk( + f"layers.{layer_i}.feed_forward.w3.weight", + loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"], + 0, + ) + + insert( + f"layers.{layer_i}.attention_norm.weight", + loaded[f"model.layers.{layer_i}.input_layernorm.weight"], + ) + insert( + f"layers.{layer_i}.ffn_norm.weight", + loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"], + ) + if llama_version != 3: + base = 10000.0 + inv_freq = ( + 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + ).to(dtype) + insert("rope.freqs", inv_freq) + + for i in tqdm(range(num_shards), desc="Saving checkpoint shards"): + 
torch.save( + state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth") + ) + + +def main( + model_path: str, + model_size: str, + output_dir: str, +): + """Convert llama weights from huggingface format to consolidated format. + params: + model_path: model name or path to the model directory. + model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B. + output_dir: directory to save Llama weights, should contains params.json. + """ + assert model_size in NUM_SHARDS, f"Unknown model size {model_size}" + params_path = os.path.join(output_dir, "params.json") + assert os.path.isfile(params_path), f"{params_path} does not exist" + + write_model(model_path, model_size, output_dir) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 375edf9fb6c..b75d9150de8 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -50,6 +50,9 @@ ) from executorch.backends.qualcomm.utils.utils import ( convert_linear_to_conv2d, + convert_qlinear_to_tman_linear, + convert_qlinear_to_linear, + convert_linear_to_qlinear, generate_composite_llama_program, generate_htp_compiler_spec, generate_multi_graph_program, @@ -66,6 +69,7 @@ from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( LlamaModel, ModelArgs, + LlamaDecoderLayer, ) from executorch.examples.qualcomm.utils import ( make_output_dir, @@ -285,6 +289,14 @@ def calibrate( raise RuntimeError("Get wrong inputs") +def permute(weights: torch.Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + class SingleLlama: def __init__(self, llama_model, pte_filename) -> None: super().__init__() @@ -534,6 +546,30 @@ def compile(args, 
pte_filename, tokenizer): else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + if args.gptq_dir: + from gptqmodel.quantization.config import QuantizeConfig + from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + qcfg = QuantizeConfig.from_pretrained(args.gptq_dir) + if qcfg.desc_act: + raise RuntimeError( + "desc_act=True is unsupported right now." + ) + qlinear_cls = partial( + TorchQuantLinear, + bits=qcfg.bits, + group_size=qcfg.group_size, + desc_act=qcfg.desc_act, + sym=qcfg.sym, + pack_dtype=qcfg.pack_dtype, + device=qcfg.device, + adapter=qcfg.adapter, + ) + for llama_instance in llama_instance_list: + layer: LlamaDecoderLayer + for layer in llama_instance.layers: + convert_linear_to_qlinear(layer.attention, qlinear_cls) + convert_linear_to_qlinear(layer.feed_forward, qlinear_cls) + if "model" in state_dict: state_dict = state_dict["model"] @@ -562,7 +598,7 @@ def permute(w, heads): for llama_instance in llama_instance_list: llama_instance.load_state_dict( state_dict, - strict=False, + strict=True, assign=True, ) end_load_ts = time.time() @@ -570,10 +606,19 @@ def permute(w, heads): for llama_instance in llama_instance_list: for layer in llama_instance.layers: - if getattr(layer.attention, "prepare_sha", None): - layer.attention.prepare_sha() - if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): - layer.feed_forward.prepare_feedfoward_conv() + if args.gptq_dir: + # TODO: optimize the performance when needed + if args.use_tman: + if getattr(layer.attention, "prepare_tman", None): + layer.attention.prepare_tman(do_permute=False, use_sha=False) + convert_qlinear_to_tman_linear(layer.feed_forward) + else: + convert_qlinear_to_linear(layer.attention) + if getattr(layer.attention, "prepare_sha", None): + layer.attention.prepare_sha() + convert_qlinear_to_linear(layer.feed_forward) + if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): + layer.feed_forward.prepare_feedfoward_conv() use_fp16 = True 
fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} @@ -607,6 +652,8 @@ def permute(w, heads): llama_instance_list[i] ) llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i]) + # llama_instance_list[i] = convert_qlinear_to_tman_linear(llama_instance_list[i]) + print(llama_instance_list[i]) llama_instance_list[i] = SingleLlama( llama_instance_list[i].eval(), pte_filename ) @@ -1083,6 +1130,19 @@ def _build_parser(): help="Fallback to cpu embedding operator and type of embedding quantization, ',', e.g., '4,32'.", ) + parser.add_argument( + "--gptq_dir", + default=None, + type=str, + help="Path to the GPTQ model dir, which should contain config.json or quantize_config.json.", + ) + + parser.add_argument( + "--use_tman", + action="store_true", + help="Use TMANLinear instead of QNNConv2d.", + ) + parser.add_argument("-v", "--verbose", action="store_true") return parser diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index f7893792e00..81056e7ca67 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -9,12 +9,15 @@ from typing import List, Optional, Tuple +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import precompute_freqs_cis +from executorch.backends.qualcomm.utils.utils import TMANLinear + def apply_rotary_emb_single( x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor @@ -34,6 +37,37 @@ def apply_rotary_emb_single( return x_out +def permute(qlinear: nn.Module, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + # Unpack the qzeros first for permutations + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qlinear.qzeros, 2).expand(-1, -1, 
qlinear.pack_factor), + qlinear.wf_unsqueeze_zero, + ).to(qlinear.dequant_dtype) + zeros = torch.bitwise_and(zeros, qlinear.maxq).reshape(qlinear.scales.shape) + + def _permute(weights: torch.Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(weights.shape[0], n_head, 2, weights.shape[1] // n_head // 2) + .swapaxes(-2, -1) + .reshape(weights.shape)) + + qweight = _permute(qlinear.qweight, n_head, n_head_kv) + scales = _permute(qlinear.scales, n_head, n_head_kv) + zeros = _permute(zeros, n_head, n_head_kv) + + zeros = zeros.numpy().astype(qlinear.pack_np_math_dtype) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // qlinear.pack_dtype_bits * qlinear.bits), dtype=qlinear.pack_np_math_dtype) + for col in range(qzeros.shape[1]): + for j in range(qlinear.pack_factor): + qzeros[:, col] |= zeros[:, col * qlinear.pack_factor + j] << (qlinear.bits * j) + qzeros = torch.from_numpy(qzeros.astype(qlinear.pack_np_dtype)) + + return qweight, scales, qzeros + + class LlamaAttention(nn.Module): def __init__(self, config: ModelArgs, output_new_cache_only=False): super().__init__() @@ -167,6 +201,155 @@ def forward_sha( return y, kh, vh + # TODO: can't find a better way to replace with tman_sha at this moment + # place it here for now + def prepare_tman(self, do_permute=False, use_sha=False): + from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + assert(isinstance(self.wq, TorchQuantLinear)) + assert(isinstance(self.wk, TorchQuantLinear)) + assert(isinstance(self.wv, TorchQuantLinear)) + assert(isinstance(self.wo, TorchQuantLinear)) + self.wq.post_init() + self.wk.post_init() + self.wv.post_init() + self.wo.post_init() + + if use_sha: + self.wq_tman = nn.ModuleList( + [ + TMANLinear(self.wq, n_splits=self.n_heads) + for _ in range(self.n_heads) + ] + ) + self.wk_tman = nn.ModuleList( + [ + TMANLinear(self.wk, n_splits=self.n_kv_heads) + for _ in range(self.n_kv_heads) + ] + 
) + self.wv_tman = nn.ModuleList( + [ + TMANLinear(self.wv, n_splits=self.n_kv_heads) + for _ in range(self.n_kv_heads) + ] + ) + else: + self.wq_tman = TMANLinear(self.wq) + self.wk_tman = TMANLinear(self.wk) + self.wv_tman = TMANLinear(self.wv) + self.wo_tman = TMANLinear(self.wo) + + self.forward_mha = self.forward + self.forward = self.forward_tman + + if do_permute: + wq_qweight, wq_scales, wq_qzeros = permute(self.wq, self.n_heads, self.n_heads) + wk_qweight, wk_scales, wk_qzeros = permute(self.wk, self.n_heads, self.n_kv_heads) + else: + wq_qweight, wq_scales, wq_qzeros = self.wq.qweight, self.wq.scales, self.wq.qzeros + wk_qweight, wk_scales, wk_qzeros = self.wk.qweight, self.wk.scales, self.wk.qzeros + wv_qweight, wv_scales, wv_qzeros = self.wv.qweight, self.wv.scales, self.wv.qzeros + + def _copy(target: nn.ModuleList | TMANLinear, source: TorchQuantLinear, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, n_head: int, n_head_kv: int | None): + if use_sha: + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + for i in range(n_head): + target[i].qweight.copy_( + qweight[ + :, i * self.head_dim : (i + 1) * self.head_dim + ] + ) + target[i].scales.copy_( + scales[ + :, i * self.head_dim : (i + 1) * self.head_dim + ] + ) + target[i].qzeros.copy_( + qzeros[ + :, i * self.head_dim // source.pack_factor : (i + 1) * self.head_dim // source.pack_factor + ] + ) + else: + target.qweight.copy_(qweight) + target.scales.copy_(scales) + target.qzeros.copy_(qzeros) + + _copy(self.wq_tman, self.wq, wq_qweight, wq_scales, wq_qzeros, self.n_heads, self.n_heads) + _copy(self.wk_tman, self.wk, wk_qweight, wk_scales, wk_qzeros, self.n_heads, self.n_kv_heads) + _copy(self.wv_tman, self.wv, wv_qweight, wv_scales, wv_qzeros, self.n_heads, self.n_kv_heads) + + def forward_tman( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: Optional[List[torch.Tensor]] = 
None, + v_caches: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seq_len, _ = hidden_states.shape + + if isinstance(self.wq_tman, nn.ModuleList): + q = [ + wq_tman(hidden_states) + for wq_tman in self.wq_tman + ] + k = [ + wk_tman(hidden_states) + for wk_tman in self.wk_tman + ] + v = [ + wv_tman(hidden_states) + for wv_tman in self.wv_tman + ] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + for i in range(len(k)): + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1) + else: + q = self.wq_tman(hidden_states).reshape(bsz, seq_len, self.n_heads, self.head_dim) + k = self.wk_tman(hidden_states).reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) + v = self.wv_tman(hidden_states).reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) + q = apply_rotary_emb_single(q, freqs_cos, freqs_sin) + k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1) + # Use split_with_sizes as only split_with_sizes is supported + q = torch.split_with_sizes(q, [1 for _ in range(self.n_heads)], dim=2) + k = torch.split_with_sizes(k, [1 for _ in range(self.n_kv_heads)], dim=1) + v = torch.split_with_sizes(v, [1 for _ in range(self.n_kv_heads)], dim=2) + q = [t.squeeze(2) for t in q] + k = [t.squeeze(1) for t in k] + v = [t.squeeze(2) for t in v] + + output_y = [] + kh, vh = [], [] + # kv cache mode + if k_caches and v_caches: + for i, _ in enumerate(k_caches): + kh.append(torch.cat([k_caches[i], k[i]], dim=-1)) + vh.append(torch.cat([v_caches[i], v[i]], dim=1)) + # batch_prefill mode + else: + kh = k + vh = v + + for i, _ in enumerate(q): + cache_idx = i // self.num_key_value_groups + attn = q[i] @ kh[cache_idx] + attn = attn / self.scale + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh[cache_idx] + + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo_tman(y) + + if self.output_new_cache_only: + return y, k, v + + return y, 
kh, vh
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
index f23cf2ec44a..7457db88769 100644
--- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -58,6 +58,10 @@ DEFINE_string(
     kv_updater,
     "How to update kv cache. Choose between SmartMask and ShiftPointer",
     "SmartMask");
 DEFINE_int32(num_iters, 1, "total num of iterations to run.");
+DEFINE_string(
+    kv_type,
+    "uint8",
+    "Type of kv cache. Choose between uint8 and float32");
 
 std::vector CollectPrompts(int argc, char** argv) {
   // Collect all prompts from command line, example usage:
@@ -85,7 +89,8 @@ int main(int argc, char** argv) {
       FLAGS_temperature,
       FLAGS_eval_mode,
       FLAGS_kv_updater,
-      FLAGS_num_iters);
+      FLAGS_num_iters,
+      FLAGS_kv_type);
   std::vector buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
   std::ofstream fout(FLAGS_output_path.c_str());
diff --git a/examples/qualcomm/oss_scripts/llama/runner/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/runner/CMakeLists.txt
new file mode 100644
index 00000000000..373cbeb60f8
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/llama/runner/CMakeLists.txt
@@ -0,0 +1,79 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# adapted from examples/models/llama/runner/CMakeLists.txt
+
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../..)
+endif()
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+
+#
+# The `__srcs` lists are defined by including ${EXECUTORCH_SRCS_FILE}.
+# +set(EXECUTORCH_SRCS_FILE + "${CMAKE_CURRENT_BINARY_DIR}/../../../../../executorch_srcs.cmake" +) + +extract_sources(${EXECUTORCH_SRCS_FILE}) + +include(${EXECUTORCH_SRCS_FILE}) + +list(TRANSFORM _llama_runner__srcs PREPEND "${EXECUTORCH_ROOT}/") + +target_include_directories( + extension_module INTERFACE ${_common_include_directories} +) + +list( + PREPEND + _llama_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner.h + ${CMAKE_CURRENT_LIST_DIR}/io_manager.cpp + ${CMAKE_CURRENT_LIST_DIR}/io_manager.h +) + +# build qnn llama runner +if(CMAKE_TOOLCHAIN_IOS + OR ANDROID + OR APPLE +) + # Building a share library on iOS requires code signing On Android we see + # duplicated registration when using shared lib + add_library(llama_runner STATIC ${_llama_runner__srcs}) +else() + add_library(llama_runner SHARED ${_llama_runner__srcs}) +endif() + +set(llama_runner_deps executorch extension_data_loader extension_module + extension_tensor qnn_executorch_backend +) + +target_link_libraries(llama_runner PUBLIC ${llama_runner_deps}) + +target_include_directories( + llama_runner + INTERFACE ${_common_include_directories} +) + +# Include tokenizers dependency +set(CMAKE_POSITION_INDEPENDENT_CODE ON) +add_subdirectory( + ${EXECUTORCH_ROOT}/extension/llm/tokenizers + ${CMAKE_CURRENT_BINARY_DIR}/tokenizers +) +target_link_libraries( + llama_runner PUBLIC tokenizers +) + +target_include_directories( + llama_runner + PUBLIC ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include +) +target_compile_options(llama_runner PUBLIC ${_preprocessor_flag}) diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp index c2bf7b04fbb..9215ddc745b 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp @@ -705,7 +705,8 @@ void ShiftPointerIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { 
ptr->kv_attention_mask[kv_cache_len_] = 65535; } -SmartMaskIoMgr::SmartMaskIoMgr( +template +SmartMaskIoMgr::SmartMaskIoMgr( std::vector>& modules, int32_t context_len, int32_t prefill_ar_len, @@ -768,7 +769,8 @@ SmartMaskIoMgr::SmartMaskIoMgr( new IO, [](void* ptr) { delete static_cast(ptr); }); } -std::unordered_map SmartMaskIoMgr::get_io_elements() { +template +std::unordered_map SmartMaskIoMgr::get_io_elements() { int32_t max_ar_len = std::max(kv_ar_len_, prefill_ar_len_); size_t cache_in_ele = num_layers_ * num_heads_ * head_dim_ * kv_cache_len_; size_t cache_out_ele = num_layers_ * num_heads_ * head_dim_ * max_ar_len; @@ -785,7 +787,8 @@ std::unordered_map SmartMaskIoMgr::get_io_elements() { {"prefill_logits_ele", prefill_ar_len_ * vocab_size_}}; } -std::unordered_map SmartMaskIoMgr::get_io_bytes() { +template +std::unordered_map SmartMaskIoMgr::get_io_bytes() { std::unordered_map element_map = get_io_elements(); auto align = [](size_t byte) { size_t alignment = MemoryAllocator::kDefaultAlignment; @@ -799,24 +802,25 @@ std::unordered_map SmartMaskIoMgr::get_io_bytes() { align(element_map["kv_input_toks_ele"] * sizeof(int32_t))}, {"kv_input_pos_bytes", align(element_map["kv_input_pos_ele"] * sizeof(int32_t))}, - {"cache_in_bytes", align(element_map["cache_in_ele"] * sizeof(uint8_t))}, + {"cache_in_bytes", align(element_map["cache_in_ele"] * sizeof(KVType))}, {"cache_out_bytes", - align(element_map["cache_out_ele"] * sizeof(uint8_t))}, + align(element_map["cache_out_ele"] * sizeof(KVType))}, {"kv_attention_mask_bytes", - align(element_map["kv_attention_mask_ele"] * sizeof(uint16_t))}, + align(element_map["kv_attention_mask_ele"] * sizeof(IOType))}, {"kv_logits_bytes", - align(element_map["kv_logits_ele"] * sizeof(uint16_t))}, + align(element_map["kv_logits_ele"] * sizeof(IOType))}, {"prefill_input_toks_bytes", align(element_map["prefill_input_toks_ele"] * sizeof(int32_t))}, {"prefill_input_pos_bytes", align(element_map["prefill_input_pos_ele"] * 
sizeof(int32_t))}, {"prefill_attention_mask_bytes", - align(element_map["prefill_attention_mask_ele"] * sizeof(uint16_t))}, + align(element_map["prefill_attention_mask_ele"] * sizeof(IOType))}, {"prefill_logits_bytes", - align(element_map["prefill_logits_ele"] * sizeof(uint16_t))}}; + align(element_map["prefill_logits_ele"] * sizeof(IOType))}}; } -void SmartMaskIoMgr::IO::init_io_ptrs( +template +void SmartMaskIoMgr::IO::init_io_ptrs( void* shared_buffer_ptr, std::unordered_map& io_bytes_map) { shared_buffer_base = shared_buffer_ptr; @@ -842,11 +846,11 @@ void SmartMaskIoMgr::IO::init_io_ptrs( k_cache_ref[i].reserve(num_heads_); v_cache_ref[i].reserve(num_heads_); for (int j = 0; j < num_heads_; ++j) { - k_cache_ref[i][j] = reinterpret_cast(cur_ptr); + k_cache_ref[i][j] = reinterpret_cast(cur_ptr); io_pos_map[cur_ptr] = cur_pos; cur_ptr += single_head_size; cur_pos += single_head_size; - v_cache_ref[i][j] = reinterpret_cast(cur_ptr); + v_cache_ref[i][j] = reinterpret_cast(cur_ptr); io_pos_map[cur_ptr] = cur_pos; cur_ptr += single_head_size; cur_pos += single_head_size; @@ -854,17 +858,17 @@ void SmartMaskIoMgr::IO::init_io_ptrs( } continue; } else if (key == "kv_attention_mask_bytes") { - kv_attention_mask = reinterpret_cast(cur_ptr); + kv_attention_mask = reinterpret_cast(cur_ptr); } else if (key == "kv_logits_bytes") { - kv_logits = reinterpret_cast(cur_ptr); + kv_logits = reinterpret_cast(cur_ptr); } else if (key == "prefill_input_toks_bytes") { prefill_input_toks = reinterpret_cast(cur_ptr); } else if (key == "prefill_input_pos_bytes") { prefill_input_pos = reinterpret_cast(cur_ptr); } else if (key == "prefill_attention_mask_bytes") { - prefill_attention_mask = reinterpret_cast(cur_ptr); + prefill_attention_mask = reinterpret_cast(cur_ptr); } else if (key == "prefill_logits_bytes") { - prefill_logits = reinterpret_cast(cur_ptr); + prefill_logits = reinterpret_cast(cur_ptr); } else { ET_LOG(Error, "Unknown pointer type: %s", key.c_str()); } @@ -875,7 +879,8 @@ 
void SmartMaskIoMgr::IO::init_io_ptrs( } } -void SmartMaskIoMgr::IO::add_custom_mem_info( +template +void SmartMaskIoMgr::IO::add_custom_mem_info( void* ptr, size_t nbytes, executorch::aten::ScalarType scalar_type, @@ -892,7 +897,8 @@ void SmartMaskIoMgr::IO::add_custom_mem_info( QnnExecuTorchAddCustomMemTensorInfo(info); } -void SmartMaskIoMgr::init_io() { +template +void SmartMaskIoMgr::init_io() { std::unordered_map io_bytes_map = get_io_bytes(); switch (eval_mode_) { @@ -931,7 +937,8 @@ void SmartMaskIoMgr::init_io() { ptr->init_io_ptrs(shared_ptr, io_bytes_map); } -void SmartMaskIoMgr::reset_io( +template +void SmartMaskIoMgr::reset_io( const std::vector>& prefill_methods_meta, const std::vector< @@ -947,7 +954,8 @@ void SmartMaskIoMgr::reset_io( std::fill(ptr->kv_attention_mask, ptr->kv_attention_mask + kv_attn_size, 0); } -void SmartMaskIoMgr::prepare_kv_io( +template +void SmartMaskIoMgr::prepare_kv_io( const std::vector>& methods_meta) { for (int i = 0; i < modules_.size(); ++i) { ET_CHECK_MSG( @@ -1020,7 +1028,7 @@ void SmartMaskIoMgr::prepare_kv_io( std::vector>& cache = (cache_group == 0 ? k_cache_in_[kv_forward_name_] : v_cache_in_[kv_forward_name_]); - uint8_t* cache_ptr = (cache_group == 0) + KVType* cache_ptr = (cache_group == 0) ? ptr->k_cache[layer + offset][head] : ptr->v_cache[layer + offset][head]; @@ -1074,7 +1082,7 @@ void SmartMaskIoMgr::prepare_kv_io( std::vector>& cache = (cache_group == 0 ? k_cache_out_[kv_forward_name_] : v_cache_out_[kv_forward_name_]); - uint8_t* cache_ptr = (cache_group == 0) + KVType* cache_ptr = (cache_group == 0) ? 
ptr->k_cache_out[layer + offset][head] : ptr->v_cache_out[layer + offset][head]; cache.emplace_back(std::make_unique( @@ -1097,7 +1105,8 @@ void SmartMaskIoMgr::prepare_kv_io( } } -void SmartMaskIoMgr::update_kv_io( +template +void SmartMaskIoMgr::update_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { @@ -1115,16 +1124,16 @@ void SmartMaskIoMgr::update_kv_io( auto& v_cache_out = v_cache_out_[kv_forward_name_]; // update v_cache by single thread, this part is cpu cache sensitive for (int i = 0; i < v_cache_in.size(); ++i) { - uint8_t* ptr_in = v_cache_in[i]->mutable_data() + pos * head_dim_; - const uint8_t* ptr_out = v_cache_out[i]->data(); - memcpy(ptr_in, ptr_out, head_dim_ * sizeof(uint8_t)); + KVType* ptr_in = v_cache_in[i]->mutable_data() + pos * head_dim_; + const KVType* ptr_out = v_cache_out[i]->data(); + memcpy(ptr_in, ptr_out, head_dim_ * sizeof(KVType)); } auto& k_cache_in = k_cache_in_[kv_forward_name_]; auto& k_cache_out = k_cache_out_[kv_forward_name_]; for (int i = 0; i < k_cache_in.size(); ++i) { - uint8_t* ptr_in = k_cache_in[i]->mutable_data() + pos; - const uint8_t* ptr_out = k_cache_out[i]->data(); + KVType* ptr_in = k_cache_in[i]->mutable_data() + pos; + const KVType* ptr_out = k_cache_out[i]->data(); for (size_t j = 0, offset = 0; j < head_dim_; ++j, offset += kv_cache_len_) { ptr_in[offset] = ptr_out[j]; @@ -1132,7 +1141,8 @@ void SmartMaskIoMgr::update_kv_io( } } -void SmartMaskIoMgr::prepare_prefill_io( +template +void SmartMaskIoMgr::prepare_prefill_io( const std::vector>& methods_meta) { for (int i = 0; i < modules_.size(); ++i) { ET_CHECK_MSG( @@ -1226,7 +1236,7 @@ void SmartMaskIoMgr::prepare_prefill_io( std::vector>& cache = (cache_group == 0 ? k_cache_in_[prefill_forward_name_] : v_cache_in_[prefill_forward_name_]); - uint8_t* cache_ptr = (cache_group == 0) + KVType* cache_ptr = (cache_group == 0) ? 
ptr->k_cache[layer + offset][head] : ptr->v_cache[layer + offset][head]; @@ -1303,7 +1313,8 @@ void SmartMaskIoMgr::prepare_prefill_io( } } -void SmartMaskIoMgr::update_prefill_to_kv_io( +template +void SmartMaskIoMgr::update_prefill_to_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { @@ -1322,18 +1333,18 @@ void SmartMaskIoMgr::update_prefill_to_kv_io( auto& v_cache_in = v_cache_in_[kv_forward_name_]; auto& v_cache_out = v_cache_out_[prefill_forward_name_]; // update v_cache by single thread, this part is cpu cache sensitive - size_t copied_size = kv_cache_len_ * head_dim_ * sizeof(uint8_t); + size_t copied_size = kv_cache_len_ * head_dim_ * sizeof(KVType); for (int i = 0; i < v_cache_in.size(); ++i) { - uint8_t* ptr_in = v_cache_in[i]->mutable_data(); - const uint8_t* ptr_out = v_cache_out[i]->data(); + KVType* ptr_in = v_cache_in[i]->mutable_data(); + const KVType* ptr_out = v_cache_out[i]->data(); memcpy(ptr_in, ptr_out, copied_size); } auto& k_cache_in = k_cache_in_[kv_forward_name_]; auto& k_cache_out = k_cache_out_[prefill_forward_name_]; for (int i = 0; i < k_cache_in.size(); ++i) { - uint8_t* ptr_in = k_cache_in[i]->mutable_data(); - const uint8_t* ptr_out = k_cache_out[i]->data(); + KVType* ptr_in = k_cache_in[i]->mutable_data(); + const KVType* ptr_out = k_cache_out[i]->data(); for (size_t j = 0, offset = 0; j < head_dim_; ++j, offset += kv_cache_len_) { for (size_t k = 0, k_stride = j * prefill_ar_len_; k < pos; k++) { @@ -1343,10 +1354,10 @@ void SmartMaskIoMgr::update_prefill_to_kv_io( } } else { // Update K is enough, copy from last to prevent from overwriting values - size_t copied_size = pos * sizeof(uint8_t); + size_t copied_size = pos * sizeof(KVType); for (int l = 0; l < num_layers_; l++) { for (int h = 0; h < num_heads_; h++) { - uint8_t* k_cache = ptr->k_cache[l][h]; + KVType* k_cache = ptr->k_cache[l][h]; for (int hd = head_dim_ - 1; hd > -1; hd--) { memcpy( k_cache + (kv_cache_len_ * hd), @@ -1358,7 +1369,8 @@ void 
SmartMaskIoMgr::update_prefill_to_kv_io( } } -void SmartMaskIoMgr::update_prefill_io( +template +void SmartMaskIoMgr::update_prefill_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { @@ -1369,19 +1381,19 @@ void SmartMaskIoMgr::update_prefill_io( auto& v_cache_in = v_cache_in_[prefill_forward_name_]; auto& v_cache_out = v_cache_out_[prefill_forward_name_]; // update v_cache by single thread, this part is cpu cache sensitive - size_t copied_size = prefill_ar_len_ * head_dim_ * sizeof(uint8_t); + size_t copied_size = prefill_ar_len_ * head_dim_ * sizeof(KVType); for (int i = 0; i < v_cache_in.size(); ++i) { - uint8_t* ptr_in = - v_cache_in[i]->mutable_data() + pos * head_dim_; - const uint8_t* ptr_out = v_cache_out[i]->data(); + KVType* ptr_in = + v_cache_in[i]->mutable_data() + pos * head_dim_; + const KVType* ptr_out = v_cache_out[i]->data(); memcpy(ptr_in, ptr_out, copied_size); } auto& k_cache_in = k_cache_in_[prefill_forward_name_]; auto& k_cache_out = k_cache_out_[prefill_forward_name_]; for (int i = 0; i < k_cache_in.size(); ++i) { - uint8_t* ptr_in = k_cache_in[i]->mutable_data(); - const uint8_t* ptr_out = k_cache_out[i]->data(); + KVType* ptr_in = k_cache_in[i]->mutable_data(); + const KVType* ptr_out = k_cache_out[i]->data(); for (size_t j = 0, offset = pos; j < head_dim_; ++j, offset += prefill_cache_len_) { for (size_t k = 0, k_stride = j * prefill_ar_len_; k < prefill_ar_len_; @@ -1393,7 +1405,8 @@ void SmartMaskIoMgr::update_prefill_io( } } -void SmartMaskIoMgr::fill_prefill_toks( +template +void SmartMaskIoMgr::fill_prefill_toks( int64_t start_pos, std::vector& prompt_tokens) { IO* ptr = static_cast(get_mutable_ptr()); @@ -1425,11 +1438,15 @@ void SmartMaskIoMgr::fill_prefill_toks( } } -void SmartMaskIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { +template +void SmartMaskIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { IO* ptr = static_cast(get_mutable_ptr()); *ptr->kv_input_toks = use_int64_token_ ? 
cur_token : static_cast(cur_token); ptr->kv_attention_mask[kv_cache_len_] = 65535; } +template class SmartMaskIoMgr; +template class SmartMaskIoMgr; + } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h index 0f10eef8ddc..7826d2f81df 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h @@ -196,6 +196,7 @@ class ShiftPointerIoMgr : public IoMgrBase { const bool use_int64_token_{false}; }; +template class SmartMaskIoMgr : public IoMgrBase { public: SmartMaskIoMgr( @@ -257,22 +258,22 @@ class SmartMaskIoMgr : public IoMgrBase { int64_t* kv_input_toks; int32_t* kv_input_pos; // layer -> head -> head_dim * seq_len - std::vector> k_cache; - std::vector> v_cache; + std::vector> k_cache; + std::vector> v_cache; // layer -> head -> head_dim - std::vector> k_cache_out; - std::vector> v_cache_out; + std::vector> k_cache_out; + std::vector> v_cache_out; // kv_ar_len_ * context_len_ - uint16_t* kv_attention_mask; + IOType* kv_attention_mask; // kv_ar_len_ * vocab_size - uint16_t* kv_logits; + IOType* kv_logits; // prefill_ar_len_ int64_t* prefill_input_toks; int32_t* prefill_input_pos; // prefill_ar_len_ * context_len_ - uint16_t* prefill_attention_mask; + IOType* prefill_attention_mask; // vocab_size * prefill_ar_len_ - uint16_t* prefill_logits; + IOType* prefill_logits; size_t num_layers_; size_t num_heads_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index dafc911a172..b4304a9ebad 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -50,7 +50,8 @@ Runner::Runner( const float temperature, const int eval_mode, const std::string& kv_updater, - const int num_iters) + const int num_iters, + const std::string& kv_type) : n_bos_(1), n_eos_(1), 
tokenizer_path_(tokenizer_path), @@ -60,7 +61,8 @@ Runner::Runner( temperature_(temperature), eval_mode_(static_cast(eval_mode)), kv_updater_(kv_updater), - num_iters_(num_iters) { + num_iters_(num_iters), + kv_type_(kv_type) { for (size_t i = 0; i < models_path.size(); ++i) { modules_.push_back(std::make_shared( models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); @@ -70,6 +72,28 @@ Runner::Runner( ET_LOG(Info, "eval mode=%d", eval_mode_); } +Runner::Runner( + const std::vector& models_path, + const std::string& tokenizer_path, + const std::string& performance_output_path_, + const float logits_scale, + const int32_t logits_offset, + const float temperature, + const int eval_mode, + const std::string& kv_updater, + const int num_iters) + : Runner( + models_path, + tokenizer_path, + performance_output_path_, + logits_scale, + logits_offset, + temperature, + eval_mode, + kv_updater, + num_iters, + "uint8") { } + bool Runner::is_loaded() const { bool loaded = true; for (const std::shared_ptr& module : modules_) { @@ -139,21 +163,41 @@ Error Runner::load() { ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers"); if (kv_updater_ == "SmartMask") { - io_mgr_ = std::make_unique( - modules_, - context_len_, - prefill_ar_len_, - prefill_cache_len_, - kv_ar_len_, - kv_cache_len_, - vocab_size_, - num_layers, - head_dim, - num_heads, - eval_mode_, - prefill_forward_name_, - kv_forward_name_, - use_int64_token_); + if (kv_type_ == "uint8") { + io_mgr_ = std::make_unique>( + modules_, + context_len_, + prefill_ar_len_, + prefill_cache_len_, + kv_ar_len_, + kv_cache_len_, + vocab_size_, + num_layers, + head_dim, + num_heads, + eval_mode_, + prefill_forward_name_, + kv_forward_name_, + use_int64_token_); + } else if (kv_type_ == "float32") { + io_mgr_ = std::make_unique>( + modules_, + context_len_, + prefill_ar_len_, + prefill_cache_len_, + kv_ar_len_, + kv_cache_len_, + vocab_size_, + num_layers, + head_dim, + num_heads, + eval_mode_, + prefill_forward_name_, 
+ kv_forward_name_, + use_int64_token_); + } else { + ET_LOG(Error, "Using an unknown kv type %s", kv_type_.c_str()); + } } else if (kv_updater_ == "ShiftPointer") { io_mgr_ = std::make_unique( modules_, @@ -246,21 +290,34 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) { } int32_t Runner::logitsToToken(const Tensor& logits_tensor, int64_t pos) { - static std::vector logits_f(vocab_size_); - const uint16_t* logits = logits_tensor.data_ptr(); - // Since the logits are for all tokens, get the last token probabilities - auto* logits_last = logits; - - // offset to the meaningful logit we want. - if (logits_tensor.sizes().data()[1] > 1) { - logits_last += pos * vocab_size_; - } + if (kv_type_ == "float32") { + // logits are float32 + float* logits = logits_tensor.data_ptr(); + // Since the logits are for all tokens, get the last token probabilities + auto* logits_last = logits; + + // offset to the meaningful logit we want. + if (logits_tensor.sizes().data()[1] > 1) { + logits_last += pos * vocab_size_; + } + return sampler_->sample(logits_last); + } else { + static std::vector logits_f(vocab_size_); + const uint16_t* logits = logits_tensor.data_ptr(); + // Since the logits are for all tokens, get the last token probabilities + auto* logits_last = logits; + + // offset to the meaningful logit we want. 
+ if (logits_tensor.sizes().data()[1] > 1) { + logits_last += pos * vocab_size_; + } - // dequantize - for (int i = 0; i < vocab_size_; i++) { - logits_f[i] = (logits_last[i] - logits_offset_) * logits_scale_; + // dequantize + for (int i = 0; i < vocab_size_; i++) { + logits_f[i] = (logits_last[i] - logits_offset_) * logits_scale_; + } + return sampler_->sample(logits_f.data()); } - return sampler_->sample(logits_f.data()); } void Runner::run_model_step( @@ -528,7 +585,7 @@ void printReport( outfile << num_tok; outfile.close(); } else { - ET_CHECK_MSG(false, "Error saving the inference speed file"); + ET_LOG(Error, "Error saving the inference speed file"); } } @@ -542,6 +599,7 @@ std::string statsToJsonString(const Runner::Stats& stats) { << "\"inference_end_ms\":" << stats.inference_end_ms << "," << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," << "\"first_token_ms\":" << stats.first_token_ms << "," + << "\"tokens_per_second\":" << (double)stats.num_generated_tokens * 1000 / (stats.inference_end_ms - stats.prompt_eval_end_ms) << "," << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; @@ -558,4 +616,8 @@ std::vector> Runner::get_methods_meta( } return methods_meta; } + +void Runner::stop() { + ET_LOG(Error, "Not implemented yet"); +} } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index e693bcd7077..2a16ad38bf0 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -19,13 +19,26 @@ #include #include +#include #include #include namespace example { -class Runner { +class Runner : public executorch::extension::llm::IRunner { public: + explicit Runner( + const std::vector& models_path, + const std::string& tokenizer_path, + const std::string& performance_output_path_, + 
const float logits_scale, + const int32_t logits_offset, + const float temperature, + const int eval_mode, + const std::string& kv_updater, + const int num_iters, + const std::string& kv_type); + explicit Runner( const std::vector& models_path, const std::string& tokenizer_path, @@ -63,15 +76,28 @@ class Runner { int64_t num_generated_tokens; }; - bool is_loaded() const; - executorch::runtime::Error load(); + bool is_loaded() const override; + executorch::runtime::Error load() override; executorch::runtime::Error generate( int32_t seq_len, const std::string& prompt, const std::string& system_prompt, std::function token_callback = {}, std::function stats_callback = {}); - void stop(); + executorch::runtime::Error generate( + const std::string& prompt, + const executorch::extension::llm::GenerationConfig& config, + std::function token_callback = {}, + std::function + stats_callback = {}) override { + // TODO: convert stats_callback + return generate( + config.seq_len, + prompt, + "", + token_callback); + } + void stop() override; std::vector> get_methods_meta(std::string& method_name); @@ -119,6 +145,7 @@ class Runner { LlamaVersion llama_version_; std::string kv_updater_; int num_iters_; + std::string kv_type_; }; } // namespace example diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 3a1fe79d8f5..54b9c94ea09 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -135,19 +135,28 @@ if(EXECUTORCH_JNI_CUSTOM_LIBRARY) ) endif() +set(ET_USE_QNN_OSS_RUNNER 1) + if(EXECUTORCH_BUILD_LLAMA_JNI) target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp) list(APPEND link_libraries llama_runner llava_runner) - target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1) + target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1 ET_USE_QNN_OSS_RUNNER=${ET_USE_QNN_OSS_RUNNER}) add_subdirectory( ${EXECUTORCH_ROOT}/examples/models/llava/runner 
${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llava/runner ) - add_subdirectory( - ${EXECUTORCH_ROOT}/examples/models/llama/runner - ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner - ) + if(ET_USE_QNN_OSS_RUNNER) + add_subdirectory( + ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner + ${CMAKE_CURRENT_BINARY_DIR}/../../examples/qualcomm/oss_scripts/llama/runner + ) + else() + add_subdirectory( + ${EXECUTORCH_ROOT}/examples/models/llama/runner + ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner + ) + endif() if(NEURON_BUFFER_ALLOCATOR_LIB) target_sources( diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 0e6731dfcd5..380a59bfcf1 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -12,8 +12,13 @@ #include #include #include +#include +#if defined(ET_USE_QNN_OSS_RUNNER) +#include +#else #include +#endif #include #include #include @@ -165,6 +170,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { tokenizer_path->toStdString().c_str(), temperature); } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { +#if defined(ET_USE_QNN_OSS_RUNNER) + // TODO: Replace with a more robust way to set the environment variable + setenv("QNN_OP_PACKAGE_PATHS", "/data/local/tmp/llama/libQnnTMANOpPackage.so:TMANOpPackageInterfaceProvider:HTP", 1); + runner_ = std::make_unique( + std::vector{model_path->toStdString().c_str()}, + tokenizer_path->toStdString().c_str(), + "", + 0.0012801217380911112f, // TODO: replace hardcoded values + 34183, + temperature, + 0, + "ShiftPointer", + 2); +#else if (data_path != nullptr) { runner_ = std::make_unique( model_path->toStdString().c_str(), @@ -175,6 +194,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { model_path->toStdString().c_str(), tokenizer_path->toStdString().c_str()); } +#endif #if defined(EXECUTORCH_BUILD_MEDIATEK) } else if (model_type_category == 
MODEL_TYPE_MEDIATEK_LLAMA) { runner_ = std::make_unique( From 85931359e06476654785df8e2737862fa392114d Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Sun, 11 May 2025 10:22:47 +0000 Subject: [PATCH 02/11] QNN Backend: Update bitnet exporter bitnet.py && Enable HF tokenizer for OSS runner --- backends/qualcomm/scripts/build.sh | 4 + backends/qualcomm/utils/utils.py | 2 +- examples/qualcomm/CMakeLists.txt | 8 +- .../qualcomm/oss_scripts/bitnet/bitnet.py | 316 +++++++++--------- .../qualcomm/oss_scripts/llama/CMakeLists.txt | 1 + examples/qualcomm/oss_scripts/llama/llama.py | 55 +-- .../oss_scripts/llama/runner/runner.cpp | 181 +++++++++- .../oss_scripts/llama/runner/runner.h | 10 +- 8 files changed, 372 insertions(+), 205 deletions(-) diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index c079dd41a2a..5a687eafb01 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -89,6 +89,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DANDROID_PLATFORM=android-30 \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -DSUPPORT_REGEX_LOOKAHEAD=ON \ -B$BUILD_ROOT cmake --build $BUILD_ROOT -j$BUILD_JOB_NUMBER --target install @@ -105,6 +106,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -DSUPPORT_REGEX_LOOKAHEAD=ON \ -B$EXAMPLE_ROOT cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER @@ -131,6 +133,7 @@ if [ "$BUILD_X86_64" = true ]; then -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -DSUPPORT_REGEX_LOOKAHEAD=ON \ -S $PRJ_ROOT \ -B $BUILD_ROOT \ @@ -153,6 +156,7 @@ if [ "$BUILD_X86_64" = true ]; then -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -DSUPPORT_REGEX_LOOKAHEAD=ON \ -B$EXAMPLE_ROOT cmake --build 
$EXAMPLE_ROOT -j$BUILD_JOB_NUMBER diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index c7714f5a279..0e0eaf0a9e9 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -28,7 +28,7 @@ tman_linear, tman_bitnet_linear, ) -from executorch.backends.qualcomm.builders.utils import unpack_weights +from executorch.backends.qualcomm.builders.utils import unpack_weights, unpack_gptqv2 from executorch.backends.qualcomm.partition.qnn_partitioner import ( generate_qnn_executorch_option, get_skip_decomp_table, diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index 4f338a23044..016260d4aa1 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -78,12 +78,8 @@ set(ABSL_PROPAGATE_CXX_STD ON) set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE}) set(CMAKE_POSITION_INDEPENDENT_CODE ON) add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/tokenizers/third-party/abseil-cpp - ${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp -) -add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/tokenizers/third-party/re2 - ${CMAKE_CURRENT_BINARY_DIR}/re2 + ${EXECUTORCH_ROOT}/extension/llm/tokenizers + ${CMAKE_CURRENT_BINARY_DIR}/tokenizers ) set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) diff --git a/examples/qualcomm/oss_scripts/bitnet/bitnet.py b/examples/qualcomm/oss_scripts/bitnet/bitnet.py index 5a5694d2b2a..15a4419c817 100644 --- a/examples/qualcomm/oss_scripts/bitnet/bitnet.py +++ b/examples/qualcomm/oss_scripts/bitnet/bitnet.py @@ -15,14 +15,20 @@ import subprocess import sys import time -from collections import OrderedDict from functools import partial from multiprocessing.connection import Client import torch -from executorch.backends.qualcomm._passes.constant_i64_to_i32 import ConstantI64toI32 +from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO +from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32 +from 
executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) -from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner +from executorch.backends.qualcomm.builders.utils import is_graph_output from executorch.backends.qualcomm.quantizer.custom_annotation import ( annotate_linear_16a8w_in_affine_layer, @@ -38,11 +44,11 @@ option_to_flatbuffer, ) from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, - QCOM_QUANTIZED_IO, + QCOM_QUANT_ATTRS_MAP, ) from executorch.backends.qualcomm.utils.utils import ( - capture_program, convert_linear_to_conv2d, convert_linear_to_bitlinear, convert_bitlinear_to_linear, @@ -50,8 +56,8 @@ generate_htp_compiler_spec, generate_multi_graph_program, generate_qnn_executorch_compiler_spec, - get_capture_program_passes, get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, update_spill_fill_size, ) @@ -59,7 +65,6 @@ from executorch.examples.models.llama.source_transformation.quantize import ( get_quant_embedding_transform, ) -from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken from executorch.examples.qualcomm.oss_scripts.bitnet.model.static_bitnet import ( BitNetForCausalLM, BitNetConfig, @@ -71,25 +76,19 @@ setup_common_args_and_variables, SimpleADB, ) -from executorch.exir import EdgeCompileConfig, EdgeProgramManager -from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.extension.llm.custom_ops import model_sharding from executorch.extension.llm.export.builder import DType -from executorch.extension.llm.tokenizer.tokenizer import ( - Tokenizer as 
SentencePieceTokenizer, -) -from executorch.extension.llm.tokenizer.hf_tokenizer import HuggingFaceTokenizer -from executorch.extension.llm.tokenizer.utils import get_tokenizer +from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer, HuggingFaceTokenizer +from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer from torch.ao.quantization.observer import MinMaxObserver from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from safetensors.torch import load_file - sys.setrecursionlimit(4096) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -150,7 +149,7 @@ def _kv_calibrate( # Llama2 tokenizer has no special tokens if isinstance(tokenizer, SentencePieceTokenizer): token_list = tokenizer.encode(user_prompts, bos=True, eos=False) - elif isinstance(tokenizer, Tiktoken): + elif isinstance(tokenizer, TiktokenTokenizer): token_list = tokenizer.encode( user_prompts, bos=True, eos=False, allowed_special="all" ) @@ -224,10 +223,12 @@ def _prefill_calibrate( # Llama2 tokenizer has no special tokens if isinstance(tokenizer, SentencePieceTokenizer): token_list = tokenizer.encode(user_prompts, bos=True, eos=False) - elif isinstance(tokenizer, Tiktoken): + elif isinstance(tokenizer, TiktokenTokenizer): token_list = tokenizer.encode( user_prompts, bos=True, eos=False, allowed_special="all" ) + elif isinstance(tokenizer, HuggingFaceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) else: raise RuntimeError("Unkown tokenizer") @@ -293,18 +294,13 @@ def calibrate( raise RuntimeError("Get wrong inputs") -def permute(weights: torch.Tensor, n_head: int, n_head_kv: int | None): - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) - - class SingleLlama: def __init__(self, 
llama_model, pte_filename) -> None: super().__init__() self.llama_model = llama_model + self.passes_job = get_capture_program_passes() + self.dep_table = get_passes_dependency_for_capture_program() + self.quant_attrs = None self.quant_dtype = None self.llama_meta = self.llama_model.get_metadata() self.has_quant_io = False @@ -318,8 +314,16 @@ def __init__(self, llama_model, pte_filename) -> None: tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) self.inputs = (tokens, atten_mask) self.llama_graph_module = llama_model + self.io_shape = { + # logit output + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_vocab_size"], + ), + } - def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): + def _tag_ios(self, node, fixed_point_type): if not self.has_quant_io: return @@ -332,14 +336,6 @@ def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): (self.llama_meta["get_head_dim"], self.llama_meta["get_ar_len"]), (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"]), } - io_shape = { - # logit output - ( - self.llama_meta["get_max_batch_size"], - self.llama_meta["get_ar_len"], - self.llama_meta["get_vocab_size"], - ), - } atten_mask_shape = { ( @@ -356,37 +352,35 @@ def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type): freq_op = { exir_ops.edge.aten.select.int, } - - for n in gm.graph.nodes: - if n.op == "placeholder": - if ( - len(users := list(n.users)) == 1 - and users[0].meta["val"].size()[-2:] in kv_cache_shape - ): - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] - elif n.meta["val"].size() in io_shape: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - elif n.meta["val"].size() in atten_mask_shape: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - elif n.op == "output": - for a in n.args[0]: - if a.meta["val"].size()[-2:] in kv_cache_shape: - a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"] - elif a.meta["val"].size() in io_shape: - 
a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - quant_attrs = a.meta["quant_attrs"] - - # Tag sharding io - if exir_ops.edge.llama.fallback.default in [ - u.target for u in list(n.users.keys()) - ] + [n.target]: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - - # Tag select op as quantized tensors for freq_sin and freq_cos. It is caused by sharding - if n.target in freq_op and n.meta["val"].size() in freq_shape: - n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"] - - return quant_attrs + quant_io_type = None + + if node.op == "placeholder": + if ( + len(users := list(node.users)) == 1 + and users[0].meta["val"].size()[-2:] in kv_cache_shape + ): + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + elif node.meta["val"].size() in atten_mask_shape: + quant_io_type = fixed_point_type["io_type"] + if is_graph_output(node): + if node.meta["val"].size()[-2:] in kv_cache_shape: + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + + # Tag sharding io + if exir_ops.edge.llama.fallback.default in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + quant_io_type = fixed_point_type["io_type"] + + # Tag select op as quantized tensors for freq_sin and freq_cos. 
It is caused by sharding + if node.target in freq_op and node.meta["val"].size() in freq_shape: + quant_io_type = fixed_point_type["io_type"] + + return quant_io_type def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): self.quant_dtype = quant_dtype @@ -405,6 +399,14 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): fx_graph_module = torch.export.export( self.llama_graph_module, self.inputs, strict=True ).module() + + if QuantDtype == QuantDtype.use_16a4w_block: + conv_nodes = [ + n for n in fx_graph_module.graph.nodes if "conv" in n.name + ] + block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes} + quantizer.set_block_size_map(block_size_map) + fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) logging.info("Quantizing the model...") @@ -424,11 +426,9 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): def lowering_modules( self, work_space, - fixed_point_type, use_fp16=False, soc_model=QcomChipset.SM8650, num_sharding=1, - passes_job=OrderedDict(), shared_buffer=False, verbose=False, ): @@ -454,32 +454,22 @@ def lowering_modules( shared_buffer=shared_buffer, ) skip_node_op_set = {"llama.fallback.default"} - partitioner = QnnPartitioner( - compiler_specs, skip_node_op_set=skip_node_op_set - ) - edge_prog = capture_program( + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( self.llama_graph_module, self.inputs, - passes_job, + compiler_specs, + constant_methods=self.llama_meta, + dep_table=self.dep_table, + passes_job=self.passes_job, + skip_node_op_set=skip_node_op_set, ) - if num_sharding > 1: - model_sharding.split_graph( - edge_prog.exported_program, - self.llama_meta["get_n_layers"], - shares=num_sharding, - ) + for n in edge_prog_mgr.exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in self.io_shape: + self.quant_attrs = output_encoding - self.quant_attrs = 
self._tag_ios( - edge_prog.exported_program.graph_module, - fixed_point_type=fixed_point_type, - ) - edge_prog_mgr = EdgeProgramManager( - edge_programs={"forward": edge_prog.exported_program}, - constant_methods=self.llama_meta, - compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) - edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) if num_sharding > 1: update_spill_fill_size(edge_prog_mgr.exported_program()) @@ -557,7 +547,6 @@ def compile(args, pte_filename, tokenizer): for llama_instance in llama_instance_list: for layer in llama_instance.model.layers: if args.use_tman: - # TODO layer.self_attn.prepare_tman() else: convert_bitlinear_to_linear(layer.self_attn) @@ -571,18 +560,18 @@ def compile(args, pte_filename, tokenizer): fixed_point_type["kv_type"] = torch.uint8 if args.ptq == "8a8w": fixed_point_type["io_type"] = torch.uint8 - elif args.ptq == "16a4w": + elif args.ptq in ("16a4w", "16a4w_block"): fixed_point_type["io_type"] = torch.uint16 else: assert args.ptq in [ "8a8w", "16a4w", - ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." + "16a4w_block", + ], f"No support for quant type {args.ptq}. Support 8a8w, 16a4w and 16a4w_block." 
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") assert args.tokenizer_model is not None, "Need tokenizer model for calibration" - passes_job = get_capture_program_passes() if args.dtype_override is not None: dtype_override = DType[args.dtype_override] for i in range(len(llama_instance_list)): @@ -595,14 +584,14 @@ def compile(args, pte_filename, tokenizer): llama_instance_list[i] = get_quant_embedding_transform(args)( llama_instance_list[i] ) - passes_job[ConstantI64toI32][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ - "skip_node" - ] = {"tokens"} llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i]) - print(llama_instance_list[i]) llama_instance_list[i] = SingleLlama( llama_instance_list[i].eval(), pte_filename ) + if args.embedding_quantize: + llama_instance_list[i].passes_job[I64toI32][ + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY + ]["skip_node"] = {"tokens"} if args.ptq: start_quantize_ts = time.time() @@ -611,44 +600,54 @@ def compile(args, pte_filename, tokenizer): custom_annotations = custom_annotations + ( annotate_linear_16a8w_in_affine_layer, ) - if args.ptq != None: - kv_quant_attrs = {} - for i, llama_instance in enumerate(llama_instance_list): - llama_instance.quantize( - quant_dtype=quant_dtype, - args=args, - tokenizer=tokenizer, - custom_annotations=custom_annotations, + kv_quant_attrs = {} + for i, llama_instance in enumerate(llama_instance_list): + llama_instance.quantize( + quant_dtype=quant_dtype, + args=args, + tokenizer=tokenizer, + custom_annotations=custom_annotations, + ) + # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later + if i == 0 and args.model_mode == "hybrid": + output_indices = 0 + for node in llama_instance.llama_graph_module.graph.nodes: + if node.op == "output": + for output in node.args[0]: + kv_quant_attrs[output_indices] = output.args[1:] + output_indices += 1 + break + custom_annotations = custom_annotations + ( + partial( + annotate_prefill_kv_output, + 
kv_quant_attrs=kv_quant_attrs, + ), ) - # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later - if i == 0 and args.model_mode == "hybrid": - output_indices = 0 - for node in llama_instance.llama_graph_module.graph.nodes: - if node.op == "output": - for output in node.args[0]: - kv_quant_attrs[output_indices] = output.args[1:] - output_indices += 1 - break - custom_annotations = custom_annotations + ( - partial( - annotate_prefill_kv_output, - kv_quant_attrs=kv_quant_attrs, - ), - ) + llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type) end_quantize_ts = time.time() logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") start_lowering_ts = time.time() quant_attrs = None + if args.num_sharding > 1: + for llama_instance in llama_instance_list: + SplitGraph, setting = model_sharding.get_split_graph_pass( + llama_instance.llama_meta["get_n_layers"], + shares=args.num_sharding, + ) + llama_instance.passes_job[SplitGraph] = setting + llama_instance.dep_table[SplitGraph] = [FoldQDQ] + llama_instance.dep_table[TagQuantIO] = [SplitGraph] if args.model_mode in ["kv"]: llama_instance_list[0].lowering_modules( args.artifact, - fixed_point_type, use_fp16=use_fp16, soc_model=get_soc_to_chipset_map()[args.model], num_sharding=args.num_sharding, - passes_job=passes_job, shared_buffer=args.shared_buffer, ) quant_attrs = llama_instance_list[0].get_quant_attrs() @@ -656,30 +655,6 @@ def compile(args, pte_filename, tokenizer): sample_inputs_list = [ llama_instace.inputs for llama_instace in llama_instance_list ] - edge_progs = [ - capture_program( - llama_instance.llama_graph_module, - sample_input, - passes_job=passes_job, - ) - for llama_instance, sample_input in zip( - llama_instance_list, sample_inputs_list - ) - ] - - if 
args.num_sharding > 1: - for i in range(len(llama_instance_list)): - model_sharding.split_graph( - edge_progs[i].exported_program, - llama_instance_list[i].llama_meta["get_n_layers"], - shares=args.num_sharding, - ) - - for i in range(len(llama_instance_list)): - quant_attrs = llama_instance_list[i]._tag_ios( - edge_progs[i].exported_program.graph_module, - fixed_point_type, - ) backend_options = generate_htp_compiler_spec( use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 1 ) @@ -696,15 +671,29 @@ def compile(args, pte_filename, tokenizer): for graph_name in graph_names ] skip_node_op_set = {"llama.fallback.default"} - exported_programs = [ - to_backend( - edge_prog.exported_program, - QnnPartitioner(compiler_specs[i], skip_node_op_set=skip_node_op_set), + edge_prog_mgrs = [ + to_edge_transform_and_lower_to_qnn( + llama_instance.llama_graph_module, + sample_input, + compile_spec, + dep_table=llama_instance.dep_table, + passes_job=llama_instance.passes_job, + skip_node_op_set=skip_node_op_set, + ) + for llama_instance, sample_input, compile_spec in zip( + llama_instance_list, sample_inputs_list, compiler_specs ) - for i, edge_prog in enumerate(edge_progs) ] + for n in edge_prog_mgrs[0].exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in llama_instance_list[0].io_shape: + quant_attrs = output_encoding + if args.num_sharding > 1: - max_sf_size = update_spill_fill_size(exported_programs) + max_sf_size = update_spill_fill_size( + [edge_prog_mgr.exported_program() for edge_prog_mgr in edge_prog_mgrs] + ) qnn_executorch_options = flatbuffer_to_option(compiler_specs[0][0].value) qnn_executorch_options.backend_options.htp_options.max_sf_buf_size = ( max_sf_size @@ -712,8 +701,8 @@ def compile(args, pte_filename, tokenizer): compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options) if args.verbose: - for exported_program in exported_programs: - 
print_delegation_info(exported_program.graph_module) + for edge_prog_mgr in edge_prog_mgrs: + print_delegation_info(edge_prog_mgr.exported_program().graph_module) executorch_config = ExecutorchBackendConfig( # For shared buffer, user must pass the memory address @@ -733,8 +722,8 @@ def compile(args, pte_filename, tokenizer): call_delegate_node_name_dict = {name: [] for name in graph_names} outputs_dict = {name: [] for name in graph_names} input_nodes_dict = {name: [] for name in graph_names} - for prog, graph_name in zip(exported_programs, graph_names): - for node in prog.graph_module.graph.nodes: + for prog, graph_name in zip(edge_prog_mgrs, graph_names): + for node in prog.exported_program().graph_module.graph.nodes: if ( node.op == "call_function" and "executorch_call_delegate" in node.name @@ -765,13 +754,15 @@ def compile(args, pte_filename, tokenizer): outputs_dict[graph_name].append((arg.args[0].name, arg.args[1])) for num in range(args.num_sharding - 1, -1, -1): processed_bytes = [] - for prog, graph_name in zip(exported_programs, graph_names): + for prog, graph_name in zip(edge_prog_mgrs, graph_names): processed_bytes.append( - getattr(prog.graph_module, f"lowered_module_{num}").processed_bytes + getattr( + prog.exported_program().graph_module, f"lowered_module_{num}" + ).processed_bytes ) call_delegate_node = [ list(node.users.keys())[0] - for node in prog.graph_module.graph.nodes + for node in prog.exported_program().graph_module.graph.nodes if node.op == "get_attr" and node.name == f"lowered_module_{num}" ] input_nodes_dict[graph_name] = [ @@ -853,6 +844,7 @@ def post_process(): ) runner_cmd = "" + performance_output_path = "outputs/inference_speed.txt" if args.enable_x86_64: # x86 emulator is intended for CI and not performance. Check only the first few tokens. 
seq_len = min(seq_len, 16) @@ -872,6 +864,7 @@ def post_process(): f"--model_path {pte_path}", f"--seq_len {seq_len}", f"--output_path {args.artifact}/outputs/outputs.txt", + f"--performance_output_path {performance_output_path}", f"--kv_updater ShiftPointer", runner_args, ] @@ -892,6 +885,7 @@ def post_process(): f"--model_path {pte_filename}.pte", f"--seq_len {seq_len}", "--output_path outputs/outputs.txt", + f"--performance_output_path {performance_output_path}", f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}", runner_args, ] @@ -915,7 +909,7 @@ def post_process(): adb.pull(output_path=args.artifact, callback=post_process) if args.ip and args.port != -1: inference_speed = 0 - with open(f"{args.artifact}/outputs/inference_speed.txt", "r") as f: + with open(f"{args.artifact}/{performance_output_path}", "r") as f: inference_speed = float(f.read()) pte_size = os.path.getsize(pte_path) @@ -947,7 +941,7 @@ def _build_parser(): parser.add_argument( "-P", "--ptq", - help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.", type=str, ) @@ -1097,7 +1091,7 @@ def export_llama(args) -> None: runtime_tokenizer_path = args.tokenizer_bin elif args.llama_model == "llama3_2": assert isinstance( - tokenizer, Tiktoken + tokenizer, TiktokenTokenizer ), f"Wrong tokenizer provided for llama3_2." 
runtime_tokenizer_path = args.tokenizer_model elif args.llama_model == "bitnet": diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index 4d4f1c2e39d..885e283faed 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -59,6 +59,7 @@ target_link_libraries( custom_ops quantized_ops_lib quantized_kernels + tokenizers ) target_compile_options( qnn_llama_runner PUBLIC ${_common_compile_options} diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index b75d9150de8..d6e883ee788 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -82,7 +82,7 @@ from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.extension.llm.custom_ops import model_sharding from executorch.extension.llm.export.builder import DType -from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer +from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer, HuggingFaceTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer from torch.ao.quantization.observer import MinMaxObserver @@ -152,6 +152,8 @@ def _kv_calibrate( token_list = tokenizer.encode( user_prompts, bos=True, eos=False, allowed_special="all" ) + elif isinstance(tokenizer, HuggingFaceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) else: raise RuntimeError("Unkown tokenizer") @@ -587,13 +589,14 @@ def permute(w, heads): n_kv_heads = llama_instance_list[0].n_kv_heads n_layers = llama_instance_list[0].n_layers - for layer_i in range(n_layers): - state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads - ) - state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( - state_dict[f"layers.{layer_i}.attention.wk.weight"], 
n_kv_heads - ) + if not args.gptq_dir: + for layer_i in range(n_layers): + state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads + ) + state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads + ) for llama_instance in llama_instance_list: llama_instance.load_state_dict( @@ -604,22 +607,6 @@ def permute(w, heads): end_load_ts = time.time() logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}") - for llama_instance in llama_instance_list: - for layer in llama_instance.layers: - if args.gptq_dir: - # TODO: optimize the performance when needed - if args.use_tman: - if getattr(layer.attention, "prepare_tman", None): - layer.attention.prepare_tman(do_permute=False, use_sha=False) - convert_qlinear_to_tman_linear(layer.feed_forward) - else: - convert_qlinear_to_linear(layer.attention) - if getattr(layer.attention, "prepare_sha", None): - layer.attention.prepare_sha() - convert_qlinear_to_linear(layer.feed_forward) - if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): - layer.feed_forward.prepare_feedfoward_conv() - use_fp16 = True fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} if args.ptq: @@ -646,6 +633,23 @@ def permute(w, heads): dtype_override.to_torch_dtype() ) + for llama_instance in llama_instance_list: + for layer in llama_instance.layers: + if args.gptq_dir: + # TODO: optimize the performance when needed + if args.use_tman: + if getattr(layer.attention, "prepare_tman", None): + layer.attention.prepare_tman(do_permute=False, use_sha=False) + convert_qlinear_to_tman_linear(layer.feed_forward) + else: + convert_qlinear_to_linear(layer.attention) + if getattr(layer.attention, "prepare_sha", None): + layer.attention.prepare_sha() + convert_qlinear_to_linear(layer.feed_forward) + if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): + 
layer.feed_forward.prepare_feedfoward_conv() + + for i in range(len(llama_instance_list)): if args.embedding_quantize: llama_instance_list[i] = get_quant_embedding_transform(args)( @@ -1175,8 +1179,11 @@ def export_llama(args) -> None: elif args.llama_model == "llama3_2": assert isinstance( tokenizer, TiktokenTokenizer + ) or isinstance( + tokenizer, HuggingFaceTokenizer ), f"Wrong tokenizer provided for llama3_2." runtime_tokenizer_path = args.tokenizer_model + # args.prompt = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n".format(args.prompt) else: raise RuntimeError(f"Unknown llama_model: {args.llama_model}.") diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index b4304a9ebad..09252f83b96 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -247,10 +248,21 @@ Error Runner::load() { tokenizer_ = std::make_unique(); err = tokenizer_->load(tokenizer_path_); llama_version_ = LlamaVersion::kLlama2; - ET_CHECK_MSG( - err == tokenizers::Error::Ok, - "failed to load tokenizer %s", - tokenizer_path_.c_str()); + if (err != tokenizers::Error::Ok) { + ET_LOG( + Info, + "Failed to load %s as a llama2.c tokenizer artifact", + tokenizer_path_.c_str()); + tokenizer_.reset(); + tokenizer_ = std::make_unique(); + err = tokenizer_->load(tokenizer_path_); + ET_CHECK_MSG( + err == tokenizers::Error::Ok, + "failed to load tokenizer %s", + tokenizer_path_.c_str()); + eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); + llama_version_ = LlamaVersion::kLlama3; + } } else { eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); llama_version_ = LlamaVersion::kLlama3; @@ -512,6 +524,167 @@ Error Runner::generate( return Error::Ok; } +Error Runner::generate( + const std::string& prompt, 
+ const executorch::extension::llm::GenerationConfig& config, + std::function token_callback, + std::function stats_callback) { + std::unordered_map>> + input_tensors, output_tensors; + std::unordered_map>> inputs; + executorch::extension::llm::Stats stats; + stats.model_load_start_ms = time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(load()); + for (auto method_name : method_names_) { + for (int i = 0; i < modules_.size(); ++i) { + input_tensors[method_name].emplace_back( + io_mgr_->get_input_tensors(i, method_name)); + output_tensors[method_name].emplace_back( + io_mgr_->get_output_tensors(i, method_name)); + for (size_t j = 0; j < output_tensors[method_name][i].size(); ++j) { + ET_CHECK_MSG( + modules_[i]->set_output( + method_name, output_tensors[method_name][i][j], j) == + Error::Ok, + "failed to set output tensor for module %d's %zu'th output", + i, + j); + } + inputs[method_name].emplace_back(std::vector( + begin(input_tensors[method_name][i]), + end(input_tensors[method_name][i]))); + } + } + stats.model_load_end_ms = time_in_ms(); + stats.inference_start_ms = time_in_ms(); + + ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null"); + + int32_t seq_len = config.seq_len; + seq_len = (seq_len > 0 && seq_len <= context_len_) ? 
seq_len : context_len_; + tokenizers::Result> encode_res = + tokenizer_->encode(prompt, n_bos_, 0); + ET_CHECK_TK_OK_OR_RETURN_ERROR( + encode_res.error(), "failed to encode prompt %s", prompt.c_str()); + + std::vector prompt_tokens = encode_res.get(); + int num_prompt_tokens = prompt_tokens.size(); + if (num_prompt_tokens >= seq_len) { + // Leave at least half of the context length for token generation + num_prompt_tokens = seq_len / 2; + prompt_tokens = std::vector( + prompt_tokens.end() - num_prompt_tokens, prompt_tokens.end()); + } + ET_CHECK_MSG( + num_prompt_tokens < seq_len, + "sequence length exceeded - please increase the seq_len value"); + + int64_t pos = 0, prev_token, cur_token = prompt_tokens[0]; + if (config.echo && token_callback) { + token_callback(prompt); + } + auto prefill_execute = [&](const std::string& method_name) { + int num_iters = 1 + ((num_prompt_tokens - 1) / prefill_ar_len_); + ET_LOG( + Info, + "Prompt Processor: total %d tokens (AR-%d * %d iters)", + num_prompt_tokens, + prefill_ar_len_, + num_iters); + + for (int i = 0; i < num_iters; i++) { + io_mgr_->fill_prefill_toks(pos, prompt_tokens); + run_model_step(method_name, inputs[method_name]); + io_mgr_->update_prefill_io(cur_token, pos, output_tensors[method_name]); + pos += prefill_ar_len_; + } + Tensor& logits_tensor = output_tensors[method_name].back()[0]; + prev_token = prompt_tokens[num_prompt_tokens - 1]; + long sample_start_time_ms = time_in_ms(); + cur_token = logitsToToken( + logits_tensor, + (num_prompt_tokens + prefill_ar_len_ - 1) % prefill_ar_len_); + stats.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; + + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + if (token_callback) { + token_callback(piece_res.get().c_str()); + } + + pos = num_prompt_tokens; + stats.first_token_ms = time_in_ms(); + stats.prompt_eval_end_ms = time_in_ms(); + }; + + auto kv_execute = [&](const std::string& method_name) { + 
io_mgr_->fill_kv_tok_mask(pos, cur_token); + while (pos < seq_len - 1) { + // inference + run_model_step(method_name, inputs[method_name]); + Tensor& logits_tensor = output_tensors[method_name].back()[0]; + + // hybrid mode will check these stats_ at prefill(prefill) + if (eval_mode_ == EvalMode::kKVCached) { + if (pos == num_prompt_tokens) { + stats.first_token_ms = time_in_ms(); + } else if (pos == num_prompt_tokens - 1) { + stats.prompt_eval_end_ms = time_in_ms(); + } + } + prev_token = cur_token; + long sample_start_time_ms = time_in_ms(); + cur_token = logitsToToken(logits_tensor, pos); + stats.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; + + if (pos < num_prompt_tokens - 1) { + cur_token = prompt_tokens[pos + 1]; + } + io_mgr_->update_kv_io(cur_token, ++pos, output_tensors[method_name]); + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + + if (token_callback && pos >= num_prompt_tokens) { + token_callback(piece_res.get().c_str()); + } + + if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) { + ET_LOG(Info, "\nReached to the end of generation"); + break; + } + } + }; + + switch (eval_mode_) { + case EvalMode::kKVCached: + kv_execute(kv_forward_name_); + break; + case EvalMode::kHybrid: + prefill_execute(prefill_forward_name_); + io_mgr_->update_prefill_to_kv_io( + cur_token, pos, output_tensors[kv_forward_name_]); + kv_execute(kv_forward_name_); + break; + default: + ET_CHECK_MSG(false, "Unsupported eval mode"); + break; + } + stats.inference_end_ms = time_in_ms(); + if (pos == seq_len) { + ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len); + } + + stats.num_prompt_tokens = num_prompt_tokens; + stats.num_generated_tokens = pos - num_prompt_tokens; + if (stats_callback) { + stats_callback(stats); + } + io_mgr_->reset_io( + get_methods_meta(prefill_forward_name_), + get_methods_meta(kv_forward_name_)); + return Error::Ok; +} + namespace { void printReport( const Runner::Stats& 
stats, diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 2a16ad38bf0..c2df018870e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -88,15 +88,7 @@ class Runner : public executorch::extension::llm::IRunner { const std::string& prompt, const executorch::extension::llm::GenerationConfig& config, std::function token_callback = {}, - std::function - stats_callback = {}) override { - // TODO: convert stats_callback - return generate( - config.seq_len, - prompt, - "", - token_callback); - } + std::function stats_callback = {}); void stop() override; std::vector> get_methods_meta(std::string& method_name); From bc8aea9fe4428f454a4ce1f5a3fd75aa8b74fe5a Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Sun, 11 May 2025 10:24:45 +0000 Subject: [PATCH 03/11] LlamaDemo: automatically download model files from presets --- .../app/src/main/AndroidManifest.xml | 2 + .../executorchllamademo/SettingsActivity.java | 178 +++++++++++++++--- .../app/src/main/res/layout/activity_main.xml | 2 +- extension/android/jni/jni_layer_llama.cpp | 28 ++- 4 files changed, 176 insertions(+), 34 deletions(-) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml b/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml index 7096a7d4e76..d4bbe5ac063 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml @@ -11,6 +11,8 @@ + + diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java index 290cbec413e..fa0066188b6 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java +++ 
b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java @@ -19,6 +19,8 @@ import android.widget.EditText; import android.widget.ImageButton; import android.widget.TextView; +import android.widget.Toast; +import android.app.ProgressDialog; import androidx.appcompat.app.AppCompatActivity; import androidx.core.content.ContextCompat; import androidx.core.graphics.Insets; @@ -29,6 +31,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.net.HttpURLConnection; +import java.io.FileOutputStream; +import java.io.FileInputStream; +import java.io.InputStream; +import java.net.URL; public class SettingsActivity extends AppCompatActivity { @@ -93,10 +100,7 @@ private void setupSettings() { view -> { setupModelSelectorDialog(); }); - tokenizerImageButton.setOnClickListener( - view -> { - setupTokenizerSelectorDialog(); - }); + tokenizerImageButton.setEnabled(false); modelTypeImageButton.setOnClickListener( view -> { setupModelTypeSelectorDialog(); @@ -324,19 +328,153 @@ private void setupBackendSelectorDialog() { backendTypeBuilder.create().show(); } + private static class ModelInfo { + String modelName; + String tokenizerUrl; + String modelUrl; + String quantAttrsUrl; + + ModelInfo(String modelName, String tokenizerUrl, String modelUrl, String quantAttrsUrl) { + this.modelName = modelName; + this.tokenizerUrl = tokenizerUrl; + this.modelUrl = modelUrl; + this.quantAttrsUrl = quantAttrsUrl; + } + } + + // Construct the model info array + private final ModelInfo[] modelInfoArray = new ModelInfo[] { + new ModelInfo( + "bitnet-b1.58-2B-4T", + "https://huggingface.co/JY-W/test_model/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn.pte?download=true", + "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn_quant_attrs.txt?download=true" + ), + }; + + private final String mOpPackageUrl = 
"https://huggingface.co/JY-W/test_model/resolve/main/libQnnTMANOpPackage.so?download=true"; + + private void downloadFileFromUrl(String fileUrl, String fileName, boolean overwrite) { + ProgressDialog progressDialog = new ProgressDialog(this); + progressDialog.setTitle("Downloading " + fileName + "..."); + progressDialog.setMessage("Please wait..."); + progressDialog.setProgressStyle(ProgressDialog.STYLE_HORIZONTAL); + progressDialog.setIndeterminate(false); + progressDialog.setMax(100); + progressDialog.setProgress(0); + progressDialog.setCancelable(false); + progressDialog.show(); + + new Thread(() -> { + try { + File outputDir = getExternalFilesDir(null); + File outputFile = new File(outputDir, fileName); + if (!outputFile.exists() || overwrite) { + URL url = new URL(fileUrl); + HttpURLConnection connection = (HttpURLConnection) url.openConnection(); + connection.connect(); + + if (connection.getResponseCode() != HttpURLConnection.HTTP_OK) { + runOnUiThread(() -> { + progressDialog.dismiss(); + Toast.makeText(this, "Failed to download file", Toast.LENGTH_SHORT).show(); + }); + return; + } + + int fileLength = connection.getContentLength(); + InputStream input = connection.getInputStream(); + + try (FileOutputStream output = new FileOutputStream(outputFile)) { + byte[] buffer = new byte[4096]; + int bytesRead; + long totalBytesRead = 0; + + while ((bytesRead = input.read(buffer)) != -1) { + totalBytesRead += bytesRead; + output.write(buffer, 0, bytesRead); + + // Update progress + int progress = (int) (totalBytesRead * 100 / fileLength); + runOnUiThread(() -> progressDialog.setProgress(progress)); + } + } + + connection.disconnect(); + } + + runOnUiThread(() -> { + progressDialog.dismiss(); + mLoadModelButton.setEnabled(true); + Toast.makeText(this, "File downloaded to " + outputFile.getAbsolutePath(), Toast.LENGTH_SHORT).show(); + }); + } catch (Exception e) { + runOnUiThread(() -> { + progressDialog.dismiss(); + Toast.makeText(this, "Error: " + e.getMessage(), 
Toast.LENGTH_SHORT).show(); + }); + } + }).start(); + } + + private void downloadModel(ModelInfo modelInfo) { + String modelFileName = modelInfo.modelName + ".pte"; + String tokenizerFileName = modelInfo.modelName + "_tokenizer.json"; + String quantAttrsFileName = modelInfo.modelName + "_quant_attrs.txt"; + String opPackageFileName = "libQnnTMANOpPackage.so"; + + File modelFile = new File(getExternalFilesDir(null), modelFileName); + File tokenizerFile = new File(getExternalFilesDir(null), tokenizerFileName); + File quantAttrsFile = new File(getExternalFilesDir(null), quantAttrsFileName); + File opPackageFile = new File(getExternalFilesDir(null), opPackageFileName); + + if (modelFile.exists() || tokenizerFile.exists() || quantAttrsFile.exists() || opPackageFile.exists()) { + runOnUiThread(() -> { + new AlertDialog.Builder(this) + .setTitle("Overwrite Existing Files") + .setMessage("Some files for this model already exist. Do you want to overwrite them?") + .setPositiveButton("Yes", (dialog, which) -> { + downloadFileFromUrl(modelInfo.modelUrl, modelFileName, true); + downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, true); + downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, true); + downloadFileFromUrl(mOpPackageUrl, opPackageFileName, true); + }) + .setNegativeButton("No", (dialog, which) -> { + downloadFileFromUrl(modelInfo.modelUrl, modelFileName, false); + downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, false); + downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, false); + downloadFileFromUrl(mOpPackageUrl, opPackageFileName, false); + }) + .show(); + }); + } else { + downloadFileFromUrl(modelInfo.modelUrl, modelFileName, true); + downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, true); + downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, true); + downloadFileFromUrl(mOpPackageUrl, opPackageFileName, true); + } + mModelFilePath = modelFile.getAbsolutePath(); + 
mModelTextView.setText(getFilenameFromPath(mModelFilePath)); + mTokenizerFilePath = tokenizerFile.getAbsolutePath(); + mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); + } + private void setupModelSelectorDialog() { - String[] pteFiles = listLocalFile("/data/local/tmp/llama/", new String[] {".pte"}); + // set a map from model name to url AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this); - modelPathBuilder.setTitle("Select model path"); + modelPathBuilder.setTitle("Select model"); + + String[] modelNames = Arrays.stream(modelInfoArray) + .map(modelInfo -> modelInfo.modelName) + .toArray(String[]::new); modelPathBuilder.setSingleChoiceItems( - pteFiles, + modelNames, -1, (dialog, item) -> { - mModelFilePath = pteFiles[item]; - mModelTextView.setText(getFilenameFromPath(mModelFilePath)); - mLoadModelButton.setEnabled(true); - dialog.dismiss(); + ModelInfo selectedModel = modelInfoArray[item]; + downloadModel(selectedModel); + dialog.dismiss(); }); modelPathBuilder.create().show(); @@ -384,24 +522,6 @@ private void setupModelTypeSelectorDialog() { modelTypeBuilder.create().show(); } - private void setupTokenizerSelectorDialog() { - String[] tokenizerFiles = - listLocalFile("/data/local/tmp/llama/", new String[] {".bin", ".json", ".model"}); - AlertDialog.Builder tokenizerPathBuilder = new AlertDialog.Builder(this); - tokenizerPathBuilder.setTitle("Select tokenizer path"); - tokenizerPathBuilder.setSingleChoiceItems( - tokenizerFiles, - -1, - (dialog, item) -> { - mTokenizerFilePath = tokenizerFiles[item]; - mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); - mLoadModelButton.setEnabled(true); - dialog.dismiss(); - }); - - tokenizerPathBuilder.create().show(); - } - private String getFilenameFromPath(String uriFilePath) { String[] segments = uriFilePath.split("/"); if (segments.length > 0) { diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml 
b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml index 84172982c54..52bf533521a 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/res/layout/activity_main.xml @@ -22,7 +22,7 @@ android:layout_height="wrap_content" android:paddingLeft="20dp" android:paddingTop="20dp" - android:text="Chat with BitNet" + android:text="Chat with Llama" android:textColor="@android:color/white" android:textSize="16sp" android:textStyle="bold" /> diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 380a59bfcf1..dde0be2ba20 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -13,6 +13,9 @@ #include #include #include +#include +#include +#include #if defined(ET_USE_QNN_OSS_RUNNER) #include @@ -171,18 +174,35 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { temperature); } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { #if defined(ET_USE_QNN_OSS_RUNNER) + // TODO: very ugly... 
Find better ways to set OpPackage and QuantAttrs + std::string model_path_str = model_path->toStdString(); + std::filesystem::path model_path_fs(model_path->toStdString()); + std::filesystem::path op_package_fs = model_path_fs.parent_path() / "libQnnTMANOpPackage.so"; + std::string op_package_str = op_package_fs.string() + ":TMANOpPackageInterfaceProvider:HTP"; + + std::string quant_attrs_str = model_path_fs.stem().string() + "_quant_attrs.txt"; + std::filesystem::path quant_attrs_fs = model_path_fs.parent_path() / quant_attrs_str; + + std::ifstream quant_attrs_file(quant_attrs_fs.string()); + if (!quant_attrs_file.is_open()) { + throw std::runtime_error("Failed to open quant_attrs file: " + quant_attrs_fs.string()); + } + float scale; + int32_t zero_point; + quant_attrs_file >> scale >> zero_point; + // TODO: Replace with a more robust way to set the environment variable - setenv("QNN_OP_PACKAGE_PATHS", "/data/local/tmp/llama/libQnnTMANOpPackage.so:TMANOpPackageInterfaceProvider:HTP", 1); + setenv("QNN_OP_PACKAGE_PATHS", op_package_str.c_str(), 1); runner_ = std::make_unique( std::vector{model_path->toStdString().c_str()}, tokenizer_path->toStdString().c_str(), "", - 0.0012801217380911112f, // TODO: replace hardcoded values - 34183, + scale, + zero_point, temperature, 0, "ShiftPointer", - 2); + 1); #else if (data_path != nullptr) { runner_ = std::make_unique( From 908124676d657d9cd5a3ed1267344bc9e2e5c934 Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Sun, 11 May 2025 17:01:37 +0000 Subject: [PATCH 04/11] QNN Backend: revert io_manager --- .../oss_scripts/llama/qnn_llama_runner.cpp | 7 +- .../oss_scripts/llama/runner/io_manager.cpp | 111 ++++++++---------- .../oss_scripts/llama/runner/io_manager.h | 17 ++- .../oss_scripts/llama/runner/runner.cpp | 78 +++--------- .../oss_scripts/llama/runner/runner.h | 12 -- 5 files changed, 73 insertions(+), 152 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp 
b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 7457db88769..f23cf2ec44a 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -58,10 +58,6 @@ DEFINE_string( "How to update kv cache. Choose between SmartMask and ShiftPointer", "SmartMask"); DEFINE_int32(num_iters, 1, "total num of iterations to run."); -DEFINE_string( - kv_type, - "Type of kv cache. Choose between uint8 and float32", - "uint8"); std::vector CollectPrompts(int argc, char** argv) { // Collect all prompts from command line, example usage: @@ -89,8 +85,7 @@ int main(int argc, char** argv) { FLAGS_temperature, FLAGS_eval_mode, FLAGS_kv_updater, - FLAGS_num_iters, - FLAGS_kv_type); + FLAGS_num_iters); std::vector buf; buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char std::ofstream fout(FLAGS_output_path.c_str()); diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp index 9215ddc745b..c2bf7b04fbb 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.cpp @@ -705,8 +705,7 @@ void ShiftPointerIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { ptr->kv_attention_mask[kv_cache_len_] = 65535; } -template -SmartMaskIoMgr::SmartMaskIoMgr( +SmartMaskIoMgr::SmartMaskIoMgr( std::vector>& modules, int32_t context_len, int32_t prefill_ar_len, @@ -769,8 +768,7 @@ SmartMaskIoMgr::SmartMaskIoMgr( new IO, [](void* ptr) { delete static_cast(ptr); }); } -template -std::unordered_map SmartMaskIoMgr::get_io_elements() { +std::unordered_map SmartMaskIoMgr::get_io_elements() { int32_t max_ar_len = std::max(kv_ar_len_, prefill_ar_len_); size_t cache_in_ele = num_layers_ * num_heads_ * head_dim_ * kv_cache_len_; size_t cache_out_ele = num_layers_ * num_heads_ * head_dim_ * max_ar_len; @@ -787,8 +785,7 @@ std::unordered_map SmartMaskIoMgr::get_io_e 
{"prefill_logits_ele", prefill_ar_len_ * vocab_size_}}; } -template -std::unordered_map SmartMaskIoMgr::get_io_bytes() { +std::unordered_map SmartMaskIoMgr::get_io_bytes() { std::unordered_map element_map = get_io_elements(); auto align = [](size_t byte) { size_t alignment = MemoryAllocator::kDefaultAlignment; @@ -802,25 +799,24 @@ std::unordered_map SmartMaskIoMgr::get_io_b align(element_map["kv_input_toks_ele"] * sizeof(int32_t))}, {"kv_input_pos_bytes", align(element_map["kv_input_pos_ele"] * sizeof(int32_t))}, - {"cache_in_bytes", align(element_map["cache_in_ele"] * sizeof(KVType))}, + {"cache_in_bytes", align(element_map["cache_in_ele"] * sizeof(uint8_t))}, {"cache_out_bytes", - align(element_map["cache_out_ele"] * sizeof(KVType))}, + align(element_map["cache_out_ele"] * sizeof(uint8_t))}, {"kv_attention_mask_bytes", - align(element_map["kv_attention_mask_ele"] * sizeof(IOType))}, + align(element_map["kv_attention_mask_ele"] * sizeof(uint16_t))}, {"kv_logits_bytes", - align(element_map["kv_logits_ele"] * sizeof(IOType))}, + align(element_map["kv_logits_ele"] * sizeof(uint16_t))}, {"prefill_input_toks_bytes", align(element_map["prefill_input_toks_ele"] * sizeof(int32_t))}, {"prefill_input_pos_bytes", align(element_map["prefill_input_pos_ele"] * sizeof(int32_t))}, {"prefill_attention_mask_bytes", - align(element_map["prefill_attention_mask_ele"] * sizeof(IOType))}, + align(element_map["prefill_attention_mask_ele"] * sizeof(uint16_t))}, {"prefill_logits_bytes", - align(element_map["prefill_logits_ele"] * sizeof(IOType))}}; + align(element_map["prefill_logits_ele"] * sizeof(uint16_t))}}; } -template -void SmartMaskIoMgr::IO::init_io_ptrs( +void SmartMaskIoMgr::IO::init_io_ptrs( void* shared_buffer_ptr, std::unordered_map& io_bytes_map) { shared_buffer_base = shared_buffer_ptr; @@ -846,11 +842,11 @@ void SmartMaskIoMgr::IO::init_io_ptrs( k_cache_ref[i].reserve(num_heads_); v_cache_ref[i].reserve(num_heads_); for (int j = 0; j < num_heads_; ++j) { - 
k_cache_ref[i][j] = reinterpret_cast(cur_ptr); + k_cache_ref[i][j] = reinterpret_cast(cur_ptr); io_pos_map[cur_ptr] = cur_pos; cur_ptr += single_head_size; cur_pos += single_head_size; - v_cache_ref[i][j] = reinterpret_cast(cur_ptr); + v_cache_ref[i][j] = reinterpret_cast(cur_ptr); io_pos_map[cur_ptr] = cur_pos; cur_ptr += single_head_size; cur_pos += single_head_size; @@ -858,17 +854,17 @@ void SmartMaskIoMgr::IO::init_io_ptrs( } continue; } else if (key == "kv_attention_mask_bytes") { - kv_attention_mask = reinterpret_cast(cur_ptr); + kv_attention_mask = reinterpret_cast(cur_ptr); } else if (key == "kv_logits_bytes") { - kv_logits = reinterpret_cast(cur_ptr); + kv_logits = reinterpret_cast(cur_ptr); } else if (key == "prefill_input_toks_bytes") { prefill_input_toks = reinterpret_cast(cur_ptr); } else if (key == "prefill_input_pos_bytes") { prefill_input_pos = reinterpret_cast(cur_ptr); } else if (key == "prefill_attention_mask_bytes") { - prefill_attention_mask = reinterpret_cast(cur_ptr); + prefill_attention_mask = reinterpret_cast(cur_ptr); } else if (key == "prefill_logits_bytes") { - prefill_logits = reinterpret_cast(cur_ptr); + prefill_logits = reinterpret_cast(cur_ptr); } else { ET_LOG(Error, "Unknown pointer type: %s", key.c_str()); } @@ -879,8 +875,7 @@ void SmartMaskIoMgr::IO::init_io_ptrs( } } -template -void SmartMaskIoMgr::IO::add_custom_mem_info( +void SmartMaskIoMgr::IO::add_custom_mem_info( void* ptr, size_t nbytes, executorch::aten::ScalarType scalar_type, @@ -897,8 +892,7 @@ void SmartMaskIoMgr::IO::add_custom_mem_info( QnnExecuTorchAddCustomMemTensorInfo(info); } -template -void SmartMaskIoMgr::init_io() { +void SmartMaskIoMgr::init_io() { std::unordered_map io_bytes_map = get_io_bytes(); switch (eval_mode_) { @@ -937,8 +931,7 @@ void SmartMaskIoMgr::init_io() { ptr->init_io_ptrs(shared_ptr, io_bytes_map); } -template -void SmartMaskIoMgr::reset_io( +void SmartMaskIoMgr::reset_io( const std::vector>& prefill_methods_meta, const std::vector< @@ 
-954,8 +947,7 @@ void SmartMaskIoMgr::reset_io( std::fill(ptr->kv_attention_mask, ptr->kv_attention_mask + kv_attn_size, 0); } -template -void SmartMaskIoMgr::prepare_kv_io( +void SmartMaskIoMgr::prepare_kv_io( const std::vector>& methods_meta) { for (int i = 0; i < modules_.size(); ++i) { ET_CHECK_MSG( @@ -1028,7 +1020,7 @@ void SmartMaskIoMgr::prepare_kv_io( std::vector>& cache = (cache_group == 0 ? k_cache_in_[kv_forward_name_] : v_cache_in_[kv_forward_name_]); - KVType* cache_ptr = (cache_group == 0) + uint8_t* cache_ptr = (cache_group == 0) ? ptr->k_cache[layer + offset][head] : ptr->v_cache[layer + offset][head]; @@ -1082,7 +1074,7 @@ void SmartMaskIoMgr::prepare_kv_io( std::vector>& cache = (cache_group == 0 ? k_cache_out_[kv_forward_name_] : v_cache_out_[kv_forward_name_]); - KVType* cache_ptr = (cache_group == 0) + uint8_t* cache_ptr = (cache_group == 0) ? ptr->k_cache_out[layer + offset][head] : ptr->v_cache_out[layer + offset][head]; cache.emplace_back(std::make_unique( @@ -1105,8 +1097,7 @@ void SmartMaskIoMgr::prepare_kv_io( } } -template -void SmartMaskIoMgr::update_kv_io( +void SmartMaskIoMgr::update_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { @@ -1124,16 +1115,16 @@ void SmartMaskIoMgr::update_kv_io( auto& v_cache_out = v_cache_out_[kv_forward_name_]; // update v_cache by single thread, this part is cpu cache sensitive for (int i = 0; i < v_cache_in.size(); ++i) { - KVType* ptr_in = v_cache_in[i]->mutable_data() + pos * head_dim_; - const KVType* ptr_out = v_cache_out[i]->data(); - memcpy(ptr_in, ptr_out, head_dim_ * sizeof(KVType)); + uint8_t* ptr_in = v_cache_in[i]->mutable_data() + pos * head_dim_; + const uint8_t* ptr_out = v_cache_out[i]->data(); + memcpy(ptr_in, ptr_out, head_dim_ * sizeof(uint8_t)); } auto& k_cache_in = k_cache_in_[kv_forward_name_]; auto& k_cache_out = k_cache_out_[kv_forward_name_]; for (int i = 0; i < k_cache_in.size(); ++i) { - KVType* ptr_in = k_cache_in[i]->mutable_data() + pos; - const 
KVType* ptr_out = k_cache_out[i]->data(); + uint8_t* ptr_in = k_cache_in[i]->mutable_data() + pos; + const uint8_t* ptr_out = k_cache_out[i]->data(); for (size_t j = 0, offset = 0; j < head_dim_; ++j, offset += kv_cache_len_) { ptr_in[offset] = ptr_out[j]; @@ -1141,8 +1132,7 @@ void SmartMaskIoMgr::update_kv_io( } } -template -void SmartMaskIoMgr::prepare_prefill_io( +void SmartMaskIoMgr::prepare_prefill_io( const std::vector>& methods_meta) { for (int i = 0; i < modules_.size(); ++i) { ET_CHECK_MSG( @@ -1236,7 +1226,7 @@ void SmartMaskIoMgr::prepare_prefill_io( std::vector>& cache = (cache_group == 0 ? k_cache_in_[prefill_forward_name_] : v_cache_in_[prefill_forward_name_]); - KVType* cache_ptr = (cache_group == 0) + uint8_t* cache_ptr = (cache_group == 0) ? ptr->k_cache[layer + offset][head] : ptr->v_cache[layer + offset][head]; @@ -1313,8 +1303,7 @@ void SmartMaskIoMgr::prepare_prefill_io( } } -template -void SmartMaskIoMgr::update_prefill_to_kv_io( +void SmartMaskIoMgr::update_prefill_to_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { @@ -1333,18 +1322,18 @@ void SmartMaskIoMgr::update_prefill_to_kv_io( auto& v_cache_in = v_cache_in_[kv_forward_name_]; auto& v_cache_out = v_cache_out_[prefill_forward_name_]; // update v_cache by single thread, this part is cpu cache sensitive - size_t copied_size = kv_cache_len_ * head_dim_ * sizeof(KVType); + size_t copied_size = kv_cache_len_ * head_dim_ * sizeof(uint8_t); for (int i = 0; i < v_cache_in.size(); ++i) { - KVType* ptr_in = v_cache_in[i]->mutable_data(); - const KVType* ptr_out = v_cache_out[i]->data(); + uint8_t* ptr_in = v_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = v_cache_out[i]->data(); memcpy(ptr_in, ptr_out, copied_size); } auto& k_cache_in = k_cache_in_[kv_forward_name_]; auto& k_cache_out = k_cache_out_[prefill_forward_name_]; for (int i = 0; i < k_cache_in.size(); ++i) { - KVType* ptr_in = k_cache_in[i]->mutable_data(); - const KVType* ptr_out = k_cache_out[i]->data(); 
+ uint8_t* ptr_in = k_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = k_cache_out[i]->data(); for (size_t j = 0, offset = 0; j < head_dim_; ++j, offset += kv_cache_len_) { for (size_t k = 0, k_stride = j * prefill_ar_len_; k < pos; k++) { @@ -1354,10 +1343,10 @@ void SmartMaskIoMgr::update_prefill_to_kv_io( } } else { // Update K is enough, copy from last to prevent from overwriting values - size_t copied_size = pos * sizeof(KVType); + size_t copied_size = pos * sizeof(uint8_t); for (int l = 0; l < num_layers_; l++) { for (int h = 0; h < num_heads_; h++) { - KVType* k_cache = ptr->k_cache[l][h]; + uint8_t* k_cache = ptr->k_cache[l][h]; for (int hd = head_dim_ - 1; hd > -1; hd--) { memcpy( k_cache + (kv_cache_len_ * hd), @@ -1369,8 +1358,7 @@ void SmartMaskIoMgr::update_prefill_to_kv_io( } } -template -void SmartMaskIoMgr::update_prefill_io( +void SmartMaskIoMgr::update_prefill_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { @@ -1381,19 +1369,19 @@ void SmartMaskIoMgr::update_prefill_io( auto& v_cache_in = v_cache_in_[prefill_forward_name_]; auto& v_cache_out = v_cache_out_[prefill_forward_name_]; // update v_cache by single thread, this part is cpu cache sensitive - size_t copied_size = prefill_ar_len_ * head_dim_ * sizeof(KVType); + size_t copied_size = prefill_ar_len_ * head_dim_ * sizeof(uint8_t); for (int i = 0; i < v_cache_in.size(); ++i) { - KVType* ptr_in = - v_cache_in[i]->mutable_data() + pos * head_dim_; - const KVType* ptr_out = v_cache_out[i]->data(); + uint8_t* ptr_in = + v_cache_in[i]->mutable_data() + pos * head_dim_; + const uint8_t* ptr_out = v_cache_out[i]->data(); memcpy(ptr_in, ptr_out, copied_size); } auto& k_cache_in = k_cache_in_[prefill_forward_name_]; auto& k_cache_out = k_cache_out_[prefill_forward_name_]; for (int i = 0; i < k_cache_in.size(); ++i) { - KVType* ptr_in = k_cache_in[i]->mutable_data(); - const KVType* ptr_out = k_cache_out[i]->data(); + uint8_t* ptr_in = k_cache_in[i]->mutable_data(); + const 
uint8_t* ptr_out = k_cache_out[i]->data(); for (size_t j = 0, offset = pos; j < head_dim_; ++j, offset += prefill_cache_len_) { for (size_t k = 0, k_stride = j * prefill_ar_len_; k < prefill_ar_len_; @@ -1405,8 +1393,7 @@ void SmartMaskIoMgr::update_prefill_io( } } -template -void SmartMaskIoMgr::fill_prefill_toks( +void SmartMaskIoMgr::fill_prefill_toks( int64_t start_pos, std::vector& prompt_tokens) { IO* ptr = static_cast(get_mutable_ptr()); @@ -1438,15 +1425,11 @@ void SmartMaskIoMgr::fill_prefill_toks( } } -template -void SmartMaskIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { +void SmartMaskIoMgr::fill_kv_tok_mask(int64_t pos, int64_t cur_token) { IO* ptr = static_cast(get_mutable_ptr()); *ptr->kv_input_toks = use_int64_token_ ? cur_token : static_cast(cur_token); ptr->kv_attention_mask[kv_cache_len_] = 65535; } -template class SmartMaskIoMgr; -template class SmartMaskIoMgr; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h index 7826d2f81df..0f10eef8ddc 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/io_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/io_manager.h @@ -196,7 +196,6 @@ class ShiftPointerIoMgr : public IoMgrBase { const bool use_int64_token_{false}; }; -template class SmartMaskIoMgr : public IoMgrBase { public: SmartMaskIoMgr( @@ -258,22 +257,22 @@ class SmartMaskIoMgr : public IoMgrBase { int64_t* kv_input_toks; int32_t* kv_input_pos; // layer -> head -> head_dim * seq_len - std::vector> k_cache; - std::vector> v_cache; + std::vector> k_cache; + std::vector> v_cache; // layer -> head -> head_dim - std::vector> k_cache_out; - std::vector> v_cache_out; + std::vector> k_cache_out; + std::vector> v_cache_out; // kv_ar_len_ * context_len_ - IOType* kv_attention_mask; + uint16_t* kv_attention_mask; // kv_ar_len_ * vocab_size - IOType* kv_logits; + uint16_t* kv_logits; // prefill_ar_len_ int64_t* prefill_input_toks; 
int32_t* prefill_input_pos; // prefill_ar_len_ * context_len_ - IOType* prefill_attention_mask; + uint16_t* prefill_attention_mask; // vocab_size * prefill_ar_len_ - IOType* prefill_logits; + uint16_t* prefill_logits; size_t num_layers_; size_t num_heads_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 09252f83b96..40e34fbd1ac 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -51,8 +51,7 @@ Runner::Runner( const float temperature, const int eval_mode, const std::string& kv_updater, - const int num_iters, - const std::string& kv_type) + const int num_iters) : n_bos_(1), n_eos_(1), tokenizer_path_(tokenizer_path), @@ -62,8 +61,7 @@ Runner::Runner( temperature_(temperature), eval_mode_(static_cast(eval_mode)), kv_updater_(kv_updater), - num_iters_(num_iters), - kv_type_(kv_type) { + num_iters_(num_iters) { for (size_t i = 0; i < models_path.size(); ++i) { modules_.push_back(std::make_shared( models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); @@ -73,28 +71,6 @@ Runner::Runner( ET_LOG(Info, "eval mode=%d", eval_mode_); } -Runner::Runner( - const std::vector& models_path, - const std::string& tokenizer_path, - const std::string& performance_output_path_, - const float logits_scale, - const int32_t logits_offset, - const float temperature, - const int eval_mode, - const std::string& kv_updater, - const int num_iters) - : Runner( - models_path, - tokenizer_path, - performance_output_path_, - logits_scale, - logits_offset, - temperature, - eval_mode, - kv_updater, - num_iters, - "uint8") { } - bool Runner::is_loaded() const { bool loaded = true; for (const std::shared_ptr& module : modules_) { @@ -164,41 +140,21 @@ Error Runner::load() { ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers"); if (kv_updater_ == "SmartMask") { - if (kv_type_ == "uint8") { - io_mgr_ = std::make_unique>( - 
modules_, - context_len_, - prefill_ar_len_, - prefill_cache_len_, - kv_ar_len_, - kv_cache_len_, - vocab_size_, - num_layers, - head_dim, - num_heads, - eval_mode_, - prefill_forward_name_, - kv_forward_name_, - use_int64_token_); - } else if (kv_type_ == "float32") { - io_mgr_ = std::make_unique>( - modules_, - context_len_, - prefill_ar_len_, - prefill_cache_len_, - kv_ar_len_, - kv_cache_len_, - vocab_size_, - num_layers, - head_dim, - num_heads, - eval_mode_, - prefill_forward_name_, - kv_forward_name_, - use_int64_token_); - } else { - ET_LOG(Error, "Using an unknown kv type %s", kv_type_.c_str()); - } + io_mgr_ = std::make_unique( + modules_, + context_len_, + prefill_ar_len_, + prefill_cache_len_, + kv_ar_len_, + kv_cache_len_, + vocab_size_, + num_layers, + head_dim, + num_heads, + eval_mode_, + prefill_forward_name_, + kv_forward_name_, + use_int64_token_); } else if (kv_updater_ == "ShiftPointer") { io_mgr_ = std::make_unique( modules_, diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index c2df018870e..44162396ad6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -27,18 +27,6 @@ namespace example { class Runner : public executorch::extension::llm::IRunner { public: - explicit Runner( - const std::vector& models_path, - const std::string& tokenizer_path, - const std::string& performance_output_path_, - const float logits_scale, - const int32_t logits_offset, - const float temperature, - const int eval_mode, - const std::string& kv_updater, - const int num_iters, - const std::string& kv_type); - explicit Runner( const std::vector& models_path, const std::string& tokenizer_path, From a8adabe2f68bfda3f551fccc63315b5ab45472eb Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Tue, 13 May 2025 03:54:36 +0000 Subject: [PATCH 05/11] QNN Backend: add GPTQ support for llama3 and qwen3 (TODO: refactor) --- 
.../qualcomm/oss_scripts/llama3/README.md | 1 + .../qualcomm/oss_scripts/llama3/llama3.py | 1198 ++++++++++++++++ .../oss_scripts/llama3/model/__init__.py | 0 .../llama3/model/configuration_llama3.py | 80 ++ .../oss_scripts/llama3/model/static_llama3.py | 549 ++++++++ examples/qualcomm/oss_scripts/qwen3/README.md | 1 + .../oss_scripts/qwen3/model/__init__.py | 0 .../qwen3/model/configuration_qwen3.py | 80 ++ .../oss_scripts/qwen3/model/static_qwen3.py | 555 ++++++++ examples/qualcomm/oss_scripts/qwen3/qwen3.py | 1201 +++++++++++++++++ 10 files changed, 3665 insertions(+) create mode 100644 examples/qualcomm/oss_scripts/llama3/README.md create mode 100644 examples/qualcomm/oss_scripts/llama3/llama3.py create mode 100644 examples/qualcomm/oss_scripts/llama3/model/__init__.py create mode 100644 examples/qualcomm/oss_scripts/llama3/model/configuration_llama3.py create mode 100644 examples/qualcomm/oss_scripts/llama3/model/static_llama3.py create mode 100644 examples/qualcomm/oss_scripts/qwen3/README.md create mode 100644 examples/qualcomm/oss_scripts/qwen3/model/__init__.py create mode 100644 examples/qualcomm/oss_scripts/qwen3/model/configuration_qwen3.py create mode 100644 examples/qualcomm/oss_scripts/qwen3/model/static_qwen3.py create mode 100644 examples/qualcomm/oss_scripts/qwen3/qwen3.py diff --git a/examples/qualcomm/oss_scripts/llama3/README.md b/examples/qualcomm/oss_scripts/llama3/README.md new file mode 100644 index 00000000000..3cc3afceb22 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama3/README.md @@ -0,0 +1 @@ +TODO: refactor qwen3, llama and bitnet diff --git a/examples/qualcomm/oss_scripts/llama3/llama3.py b/examples/qualcomm/oss_scripts/llama3/llama3.py new file mode 100644 index 00000000000..509874f4a20 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama3/llama3.py @@ -0,0 +1,1198 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. 
+# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# TODO: reenable pyre after fixing the issues +# pyre-ignore-all-errors + +import copy +import getpass +import json +import logging +import os +import subprocess +import sys +import time +from functools import partial +from multiprocessing.connection import Client + +import torch +from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO +from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32 +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) + +from executorch.backends.qualcomm.builders.utils import is_graph_output + +from executorch.backends.qualcomm.quantizer.custom_annotation import ( + annotate_linear_16a8w_in_affine_layer, + annotate_matmul_16a8w, + annotate_prefill_kv_output, +) + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset + +from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( + flatbuffer_to_option, + option_to_flatbuffer, +) +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANT_ATTRS_MAP, +) +from executorch.backends.qualcomm.utils.utils import ( + convert_linear_to_conv2d, + convert_linear_to_qlinear, + convert_qlinear_to_tman_linear, + convert_qlinear_to_linear, + generate_composite_llama_program, + generate_htp_compiler_spec, + generate_multi_graph_program, + generate_qnn_executorch_compiler_spec, + get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, + update_spill_fill_size, +) + +from executorch.devtools.backend_debug import print_delegation_info +from 
executorch.examples.models.llama.source_transformation.quantize import ( + get_quant_embedding_transform, +) +from executorch.examples.qualcomm.oss_scripts.llama3.model.static_llama3 import ( + Llama3ForCausalLM, + Llama3Config, + Llama3DecoderLayer, +) +from executorch.examples.qualcomm.utils import ( + make_output_dir, + make_quantizer, + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from executorch.extension.llm.custom_ops import model_sharding +from executorch.extension.llm.export.builder import DType +from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer, HuggingFaceTokenizer +from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer + +from torch.ao.quantization.observer import MinMaxObserver +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from gptqmodel import GPTQModel +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + +sys.setrecursionlimit(4096) +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def smart_mask_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +): + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. 
+ if pos >= ar_len: + for i, k_cache in enumerate(k_caches): + k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0] + + for i, v_cache in enumerate(v_caches): + v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :] + atten_mask[:, :, pos - ar_len] = 0 + + pos += 1 + return (atten_mask, pos, k_caches, v_caches) + + +def shift_pointer_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +): + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. + if pos >= ar_len: + k_caches = [ + torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + atten_mask[:, :, -pos - 1] = 0 + + pos += 1 + return (atten_mask, pos, k_caches, v_caches) + + +def _kv_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + ar_len=1, + max_seq_len=512, + updater=smart_mask_updater, + use_i64_token=False, +): + _, atten_mask, _, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0) + + token_list = [] + # Llama2 tokenizer has no special tokens + if isinstance(tokenizer, SentencePieceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + elif isinstance(tokenizer, TiktokenTokenizer): + token_list = tokenizer.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) + elif isinstance(tokenizer, HuggingFaceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + else: + raise RuntimeError("Unkown tokenizer") + + pos = len(token_list) if len(token_list) < ar_len else ar_len + dtype = torch.int64 if use_i64_token else torch.int32 + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: + 
tmp_token_list = torch.tensor( + token_list[pos - ar_len : pos], dtype=dtype + ).reshape(1, -1) + tmp_pos = all_pos[:, pos - ar_len : pos] + tmp_atten_mask = atten_mask + if pos < ar_len: + tmp_token_list = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=dtype), + torch.tensor(token_list, dtype=dtype).reshape(1, -1), + ], + dim=1, + ) + tmp_pos = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=torch.int32), + all_pos[:, :pos], + ], + dim=1, + ) + tmp_atten_mask = torch.cat( + [ + torch.ones(1, ar_len, max_seq_len - pos) * -255.0, + atten_mask[:, :, -pos:], + ], + dim=-1, + ) + + logits, new_k_caches, new_v_caches = module( + tmp_token_list, + tmp_atten_mask, + tmp_pos, + *k_caches, + *v_caches, + ) + atten_mask, pos, k_caches, v_caches = updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches + ) + if pos > len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + + print(f"kv calibration data:\n{tokenizer.decode(token_list)}") + + +def _prefill_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + max_seq_len=512, + use_i64_token=False, +): + _, atten_mask = example_inputs + + # TODO: change criteria & support batch inputs if necessary + + token_list = [] + # Llama2 tokenizer has no special tokens + if isinstance(tokenizer, SentencePieceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + elif isinstance(tokenizer, TiktokenTokenizer): + token_list = tokenizer.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) + elif isinstance(tokenizer, HuggingFaceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + else: + raise RuntimeError("Unkown tokenizer") + + pos = len(token_list) + dtype = torch.int64 if use_i64_token else torch.int32 + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: + tmp_token_list = torch.tensor(token_list, dtype=dtype).reshape(1, -1) + 
if pos < max_seq_len: + tmp_token_list = torch.cat( + [ + tmp_token_list, + torch.zeros((1, max_seq_len - pos), dtype=dtype), + ], + dim=1, + ) + results = module( + tmp_token_list, + atten_mask, + ) + if len(results) == 3: + logits, new_k_caches, new_v_caches = results + elif len(results) == 1: + logits = results + token_list.append(torch.argmax(logits[:, pos - 1], dim=-1).item()) + pos += 1 + + print(f"prefill calibration data:\n{tokenizer.decode(token_list)}") + + +def calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + ar_len=1, + max_seq_len=512, + kv_updater=smart_mask_updater, + use_i64_token=False, +): + if len(example_inputs) == 2: + _prefill_calibrate( + example_inputs, + user_prompts, + module, + tokenizer, + max_seq_len, + use_i64_token, + ) + elif len(example_inputs) == 5: + _kv_calibrate( + example_inputs, + user_prompts, + module, + tokenizer, + ar_len, + max_seq_len, + updater=kv_updater, + use_i64_token=use_i64_token, + ) + else: + raise RuntimeError("Get wrong inputs") + + +class SingleLlama: + def __init__(self, llama_model, pte_filename) -> None: + super().__init__() + self.llama_model = llama_model + self.passes_job = get_capture_program_passes() + self.dep_table = get_passes_dependency_for_capture_program() + self.quant_attrs = None + self.quant_dtype = None + self.llama_meta = self.llama_model.get_metadata() + self.has_quant_io = False + self.pte_filename = pte_filename + if self.llama_meta["get_use_kv_cache"]: + tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs( + use_kv_cache=True + ) + self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches) + else: + tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) + self.inputs = (tokens, atten_mask) + self.llama_graph_module = llama_model + self.io_shape = { + # logit output + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_vocab_size"], + ), + } + + def 
_tag_ios(self, node, fixed_point_type): + if not self.has_quant_io: + return + + # shape of k caches and v caches + kv_cache_shape = { + # single head, kv input + (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), + (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), + # single head, kv output + (self.llama_meta["get_head_dim"], self.llama_meta["get_ar_len"]), + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"]), + } + + atten_mask_shape = { + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_max_seq_len"], + ), + } + + freq_shape = { + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"] // 2), + } + + freq_op = { + exir_ops.edge.aten.select.int, + } + quant_io_type = None + + if node.op == "placeholder": + if ( + len(users := list(node.users)) == 1 + and users[0].meta["val"].size()[-2:] in kv_cache_shape + ): + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + elif node.meta["val"].size() in atten_mask_shape: + quant_io_type = fixed_point_type["io_type"] + if is_graph_output(node): + if node.meta["val"].size()[-2:] in kv_cache_shape: + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + + # Tag sharding io + if exir_ops.edge.llama.fallback.default in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + quant_io_type = fixed_point_type["io_type"] + + # Tag select op as quantized tensors for freq_sin and freq_cos. 
It is caused by sharding + if node.target in freq_op and node.meta["val"].size() in freq_shape: + quant_io_type = fixed_point_type["io_type"] + + return quant_io_type + + def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): + self.quant_dtype = quant_dtype + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) + quantizer.add_custom_quant_annotations(custom_annotations) + + self.has_quant_io = True + fx_graph_module = None + + with torch.no_grad(): + fx_graph_module = torch.export.export( + self.llama_graph_module, self.inputs, strict=True + ).module() + + if QuantDtype == QuantDtype.use_16a4w_block: + conv_nodes = [ + n for n in fx_graph_module.graph.nodes if "conv" in n.name + ] + block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes} + quantizer.set_block_size_map(block_size_map) + + fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) + + logging.info("Quantizing the model...") + calibrate( + self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), + args.prompt, + fx_graph_module, + tokenizer=tokenizer, + ar_len=self.llama_meta["get_ar_len"], + max_seq_len=self.llama_meta["get_max_seq_len"], + kv_updater=args.kv_updater, + use_i64_token=args.embedding_quantize is not None, + ) + + self.llama_graph_module = convert_pt2e(fx_graph_module) + + def lowering_modules( + self, + work_space, + use_fp16=False, + soc_model=QcomChipset.SM8650, + num_sharding=1, + shared_buffer=False, + verbose=False, + ): + executorch_config = ExecutorchBackendConfig( + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + with torch.no_grad(): + # backend option + backend_options = generate_htp_compiler_spec( + use_fp16=use_fp16, use_multi_contexts=num_sharding > 1 + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=soc_model, + backend_options=backend_options, + shared_buffer=shared_buffer, + ) + skip_node_op_set = {"llama.fallback.default"} + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + self.llama_graph_module, + self.inputs, + compiler_specs, + constant_methods=self.llama_meta, + dep_table=self.dep_table, + passes_job=self.passes_job, + skip_node_op_set=skip_node_op_set, + ) + + for n in edge_prog_mgr.exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in self.io_shape: + self.quant_attrs = output_encoding + + if num_sharding > 1: + update_spill_fill_size(edge_prog_mgr.exported_program()) + + if verbose: + print_delegation_info(edge_prog_mgr.exported_program().graph_module) + + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + def get_example_inputs(self, use_kv_cache=True): + return self.llama_model.get_example_inputs(use_kv_cache) + + def get_quant_attrs(self): + return self.quant_attrs + + +def compile(args, pte_filename, tokenizer): + os.makedirs(args.artifact, exist_ok=True) + start_ts = time.time() + + config = Llama3Config.from_pretrained(args.model_dir) + + llama_instance_list = [] + use_i64_token = args.embedding_quantize is not None + with torch.device("meta"): + if args.model_mode == "kv": + llama_instance_list.append( + Llama3ForCausalLM( + config, + ar_len=1, + max_seq_len=args.max_seq_len, + use_i64_token=use_i64_token, + ) + ) + elif args.model_mode == "hybrid": + 
def compile(args, pte_filename, tokenizer):
    """Build, quantize and lower the llama model(s); return logit quant attrs."""
    os.makedirs(args.artifact, exist_ok=True)
    start_ts = time.time()

    config = Llama3Config.from_pretrained(args.model_dir)

    llama_instance_list = []
    use_i64_token = args.embedding_quantize is not None
    # Instantiate on the meta device; real weights are assigned below.
    with torch.device("meta"):
        if args.model_mode == "kv":
            llama_instance_list.append(
                Llama3ForCausalLM(
                    config,
                    ar_len=1,
                    max_seq_len=args.max_seq_len,
                    use_i64_token=use_i64_token,
                )
            )
        elif args.model_mode == "hybrid":
            # Hybrid = one decode (ar_len=1) graph plus one prefill graph.
            llama_instance_list.append(
                Llama3ForCausalLM(
                    config,
                    ar_len=1,
                    max_seq_len=args.max_seq_len,
                    use_i64_token=use_i64_token,
                )
            )
            llama_instance_list.append(
                Llama3ForCausalLM(
                    config,
                    ar_len=args.prefill_ar_len,
                    max_seq_len=args.max_seq_len,
                    use_i64_token=use_i64_token,
                )
            )
        else:
            raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

    # Load the GPTQ-quantized checkpoint and swap Linear -> QuantLinear.
    qmodel = GPTQModel.from_quantized(args.model_dir)
    state_dict = qmodel.model.state_dict()
    qcfg = qmodel.quantize_config
    if qcfg.desc_act:
        raise RuntimeError("desc_act=True is unsupported right now.")
    qlinear_cls = partial(
        TorchQuantLinear,
        bits=qcfg.bits,
        group_size=qcfg.group_size,
        desc_act=qcfg.desc_act,
        sym=qcfg.sym,
        pack_dtype=qcfg.pack_dtype,
        device=qcfg.device,
        adapter=qcfg.adapter,
    )
    for llama_instance in llama_instance_list:
        for layer in llama_instance.model.layers:
            layer: Llama3DecoderLayer
            convert_linear_to_qlinear(layer.self_attn, qlinear_cls)
            convert_linear_to_qlinear(layer.mlp, qlinear_cls)

    for llama_instance in llama_instance_list:
        incompatible_keys = llama_instance.load_state_dict(
            state_dict,
            strict=False,
            assign=True,
        )
        # Only lm_head.weight may be missing (tied embeddings); nothing extra.
        assert (
            len(incompatible_keys.missing_keys) <= 1
            and len(incompatible_keys.unexpected_keys) == 0
        )
        if "lm_head.weight" in incompatible_keys.missing_keys:
            llama_instance.tie_weights()
    end_load_ts = time.time()
    logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")

    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
        for idx in range(len(llama_instance_list)):
            llama_instance_list[idx] = llama_instance_list[idx].to(
                dtype_override.to_torch_dtype()
            )

    # Either keep TMAN custom-op linears or fall back to plain linears + SHA.
    for llama_instance in llama_instance_list:
        for layer in llama_instance.model.layers:
            if args.use_tman:
                convert_qlinear_to_tman_linear(layer.self_attn)
                layer.self_attn.prepare_tman()
                convert_qlinear_to_tman_linear(layer.mlp)
            else:
                convert_qlinear_to_linear(layer.self_attn)
                layer.self_attn.prepare_sha()
                convert_qlinear_to_linear(layer.mlp)

    use_fp16 = True
    fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32}
    if args.ptq:
        use_fp16 = False
        fixed_point_type["kv_type"] = torch.uint8
        if args.ptq == "8a8w":
            fixed_point_type["io_type"] = torch.uint8
        elif args.ptq in ("16a4w", "16a4w_block"):
            fixed_point_type["io_type"] = torch.uint16
        else:
            assert args.ptq in [
                "8a8w",
                "16a4w",
                "16a4w_block",
            ], f"No support for quant type {args.ptq}. Support 8a8w, 16a4w and 16a4w_block."
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

        assert args.tokenizer_model is not None, "Need tokenizer model for calibration"

    for idx in range(len(llama_instance_list)):
        if args.embedding_quantize:
            llama_instance_list[idx] = get_quant_embedding_transform(args)(
                llama_instance_list[idx]
            )
        llama_instance_list[idx] = convert_linear_to_conv2d(llama_instance_list[idx])
        llama_instance_list[idx] = SingleLlama(
            llama_instance_list[idx].eval(), pte_filename
        )
        if args.embedding_quantize:
            llama_instance_list[idx].passes_job[I64toI32][
                QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY
            ]["skip_node"] = {"tokens"}

    if args.ptq:
        start_quantize_ts = time.time()
        custom_annotations = (annotate_matmul_16a8w,)
        if args.llama_model == "stories110m":
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
            )
        kv_quant_attrs = {}
        for i, llama_instance in enumerate(llama_instance_list):
            llama_instance.quantize(
                quant_dtype=quant_dtype,
                args=args,
                tokenizer=tokenizer,
                custom_annotations=custom_annotations,
            )
            # If hybrid mode, we store kv output quant_attrs and apply to
            # prefill output quant_attrs later
            if i == 0 and args.model_mode == "hybrid":
                output_indices = 0
                for node in llama_instance.llama_graph_module.graph.nodes:
                    if node.op == "output":
                        for output in node.args[0]:
                            kv_quant_attrs[output_indices] = output.args[1:]
                            output_indices += 1
                        break
                custom_annotations = custom_annotations + (
                    partial(
                        annotate_prefill_kv_output,
                        kv_quant_attrs=kv_quant_attrs,
                    ),
                )
            llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
            llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
                "get_quant_io_dtype_fn"
            ] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type)
        end_quantize_ts = time.time()
        logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")

    start_lowering_ts = time.time()
    quant_attrs = None
    if args.num_sharding > 1:
        for llama_instance in llama_instance_list:
            SplitGraph, setting = model_sharding.get_split_graph_pass(
                llama_instance.llama_meta["get_n_layers"],
                shares=args.num_sharding,
            )
            llama_instance.passes_job[SplitGraph] = setting
            llama_instance.dep_table[SplitGraph] = [FoldQDQ]
            llama_instance.dep_table[TagQuantIO] = [SplitGraph]

    if args.model_mode in ["kv"]:
        llama_instance_list[0].lowering_modules(
            args.artifact,
            use_fp16=use_fp16,
            soc_model=get_soc_to_chipset_map()[args.model],
            num_sharding=args.num_sharding,
            shared_buffer=args.shared_buffer,
        )
        quant_attrs = llama_instance_list[0].get_quant_attrs()
    elif args.model_mode == "hybrid":
        sample_inputs_list = [instance.inputs for instance in llama_instance_list]
        backend_options = generate_htp_compiler_spec(
            use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 1
        )
        graph_names = ["kv_forward", "prefill_forward"]
        compiler_specs = [
            generate_qnn_executorch_compiler_spec(
                soc_model=get_soc_to_chipset_map()[args.model],
                backend_options=backend_options,
                shared_buffer=args.shared_buffer,
                multiple_graphs=True,
                weight_sharing=not args.enable_x86_64,  # x86 emulator does not support weight sharing
                graph_name=graph_name,
            )
            for graph_name in graph_names
        ]
        skip_node_op_set = {"llama.fallback.default"}
        edge_prog_mgrs = [
            to_edge_transform_and_lower_to_qnn(
                llama_instance.llama_graph_module,
                sample_input,
                compile_spec,
                dep_table=llama_instance.dep_table,
                passes_job=llama_instance.passes_job,
                skip_node_op_set=skip_node_op_set,
            )
            for llama_instance, sample_input, compile_spec in zip(
                llama_instance_list, sample_inputs_list, compiler_specs
            )
        ]
        # Logit encoding is taken from the kv graph (index 0).
        for n in edge_prog_mgrs[0].exported_program().graph.nodes:
            if n.op == "output":
                for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items():
                    if node.meta["val"].size() in llama_instance_list[0].io_shape:
                        quant_attrs = output_encoding

        if args.num_sharding > 1:
            max_sf_size = update_spill_fill_size(
                [edge_prog_mgr.exported_program() for edge_prog_mgr in edge_prog_mgrs]
            )
            qnn_executorch_options = flatbuffer_to_option(compiler_specs[0][0].value)
            qnn_executorch_options.backend_options.htp_options.max_sf_buf_size = (
                max_sf_size
            )
            compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options)

        if args.verbose:
            for edge_prog_mgr in edge_prog_mgrs:
                print_delegation_info(edge_prog_mgr.exported_program().graph_module)

        executorch_config = ExecutorchBackendConfig(
            # For shared buffer, user must pass the memory address
            # which is allocated by RPC memory to executor runner.
            # Therefore, won't want to pre-allocate
            # by memory manager in runtime.
            memory_planning_pass=MemoryPlanningPass(
                alloc_graph_input=False,
                alloc_graph_output=False,
            ),
            extract_delegate_segments=True,
        )

        bundle_progs_list = []
        lower_module_dict = {name: [] for name in graph_names}
        call_delegate_inputs_dict = {name: [] for name in graph_names}
        call_delegate_node_name_dict = {name: [] for name in graph_names}
        outputs_dict = {name: [] for name in graph_names}
        input_nodes_dict = {name: [] for name in graph_names}
        # Record per-graph delegate call sites, their inputs and outputs.
        for prog, graph_name in zip(edge_prog_mgrs, graph_names):
            for node in prog.exported_program().graph_module.graph.nodes:
                if (
                    node.op == "call_function"
                    and "executorch_call_delegate" in node.name
                ):
                    call_delegate_node_name_dict[graph_name].append(node.name)
                    call_delegate_inputs_list = []
                    for arg in node.args:
                        if arg.op == "call_function":
                            if (
                                arg.target
                                == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype
                            ):
                                call_delegate_inputs_list.append((arg.name, None))
                            else:
                                while "getitem" not in arg.name:
                                    arg = arg.args[0]
                                call_delegate_inputs_list.append(
                                    (arg.args[0].name, arg.args[1])
                                )
                        elif arg.op == "placeholder":
                            call_delegate_inputs_list.append((arg.name, None))
                        # No extra needs to do for get_attr node
                    call_delegate_inputs_dict[graph_name].append(
                        call_delegate_inputs_list
                    )
                elif node.op == "output":
                    for arg in node.args[0]:
                        outputs_dict[graph_name].append((arg.args[0].name, arg.args[1]))
        # Recombine the sharded lowered modules, last shard first.
        for num in range(args.num_sharding - 1, -1, -1):
            processed_bytes = []
            for prog, graph_name in zip(edge_prog_mgrs, graph_names):
                processed_bytes.append(
                    getattr(
                        prog.exported_program().graph_module, f"lowered_module_{num}"
                    ).processed_bytes
                )
                call_delegate_node = [
                    list(node.users.keys())[0]
                    for node in prog.exported_program().graph_module.graph.nodes
                    if node.op == "get_attr" and node.name == f"lowered_module_{num}"
                ]
                input_nodes_dict[graph_name] = [
                    node
                    for node in call_delegate_node[0].args
                    if node.op == "placeholder"
                    or node.target
                    == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype
                ]
            prog_mgr, bundle_progs = generate_multi_graph_program(
                compiler_specs=compiler_specs[0],
                processed_bytes=processed_bytes,
                input_nodes_dict=input_nodes_dict,
                backend_config=executorch_config,
                constant_methods=llama_instance_list[0].llama_meta,  # kv method meta
            )
            bundle_progs_list.append(bundle_progs)
            for graph_name in graph_names:
                lower_module_dict[graph_name].append(
                    prog_mgr.exported_program(graph_name).graph_module._modules.get(
                        "lowered_module_0"
                    )
                )
        exec_prog = generate_composite_llama_program(
            llama_model=llama_instance_list[1].llama_model,
            graph_names=graph_names,
            sample_inputs_list=sample_inputs_list,
            lower_module_dict=lower_module_dict,
            call_delegate_node_name_dict=call_delegate_node_name_dict,
            call_delegate_inputs_dict=call_delegate_inputs_dict,
            outputs_dict=outputs_dict,
            embedding_quantize=args.embedding_quantize,
            backend_config=executorch_config,
            constant_methods=llama_instance_list[1].llama_meta,  # kv method meta
        )
        with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
            exec_prog.write_to_file(file)

    end_lowering_ts = time.time()
    logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}")
    return quant_attrs
def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
    """Run the compiled .pte on device (adb) or the x86 emulator and report results."""
    workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

    if args.model_mode == "kv":
        eval_mode = 0
    elif args.model_mode == "hybrid":
        eval_mode = 1
    else:
        raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

    pte_path = (
        f"{pre_gen_pte}/{pte_filename}.pte"
        if pre_gen_pte
        else f"{args.artifact}/{pte_filename}.pte"
    )

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    outputs = []

    def post_process():
        with open(f"{args.artifact}/outputs/outputs.txt", "r") as f:
            outputs.append(f.read())

    seq_len = args.max_seq_len
    runner_args = " ".join(
        [
            f'--prompt "{args.prompt}"',
            f"--eval_mode {eval_mode}",
            f"--temperature {args.temperature}",
            f"--system_prompt '{args.system_prompt}'",
            f"--logits_scale {quant_attrs['scale']}",
            f"--logits_offset {quant_attrs['zero_point']}",
        ]
    )

    runner_cmd = ""
    performance_output_path = "outputs/inference_speed.txt"
    if args.enable_x86_64:
        # x86 emulator is intended for CI and not performance. Check only the
        # first few tokens.
        seq_len = min(seq_len, 16)

        if args.kv_updater == smart_mask_updater:
            logging.warning(
                "x86 only support ShiftPointer, overwrite kv_updater to ShiftPointer"
            )

        qnn_sdk = os.getenv("QNN_SDK_ROOT")
        target = "x86_64-linux-clang"
        runner_cmd = " ".join(
            [
                f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&",
                f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
                f"--tokenizer_path {runtime_tokenizer_path}",
                f"--model_path {pte_path}",
                f"--seq_len {seq_len}",
                f"--output_path {args.artifact}/outputs/outputs.txt",
                f"--performance_output_path {performance_output_path}",
                f"--kv_updater ShiftPointer",
                runner_args,
            ]
        )
        subprocess.run(
            runner_cmd,
            shell=True,
            executable="/bin/bash",
            capture_output=True,
        )
        post_process()
    else:
        runner_cmd = " ".join(
            [
                f"cd {workspace} &&",
                f"./qnn_llama_runner",
                f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}",
                f"--model_path {pte_filename}.pte",
                f"--seq_len {seq_len}",
                "--output_path outputs/outputs.txt",
                f"--performance_output_path {performance_output_path}",
                f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
                runner_args,
            ]
        )

        adb = SimpleADB(
            qnn_sdk=os.getenv("QNN_SDK_ROOT"),
            build_path=f"{args.build_folder}",
            pte_path=pte_path,
            workspace=workspace,
            device_id=args.device,
            host_id=args.host,
            soc_model=args.model,
            shared_buffer=args.shared_buffer,
            runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner",
        )
        # No pregen inputs, input_list is not required
        adb.push(inputs=[], input_list="", files=[runtime_tokenizer_path])
        adb.execute(custom_runner_cmd=runner_cmd)
        adb.pull(output_path=args.artifact, callback=post_process)

    if args.ip and args.port != -1:
        inference_speed = 0
        with open(f"{args.artifact}/{performance_output_path}", "r") as f:
            inference_speed = float(f.read())

        pte_size = os.path.getsize(pte_path)
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs,
                        "pte_size": pte_size,
                        "inference_speed": inference_speed,
                    }
                )
            )
    else:
        for idx, output in enumerate(outputs):
            logging.info(f"Results[{idx}]:\n{output}")


def _build_parser():
    """Build the CLI argument parser on top of the common QNN example args."""
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts and output by this example. Default ./llama_qnn",
        default="./llama_qnn",
        type=str,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.",
        type=str,
    )

    parser.add_argument(
        "--llama_model",
        choices=["stories110m", "llama3_2"],
        help="The Llama model to export. Current available options are: [stories110m, llama3_2]",
        required=True,
    )

    parser.add_argument(
        "--model_dir",
        help="Pass llama checkpoint.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_bin",
        help="For Llama2. Pass Llama2 tokenizer binary.",
        required=False,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_model",
        help="Pass llama tokenizer model.",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        help="User prompts for llama.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--system_prompt",
        help="For Llama3. Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None",
        default="",
        type=str,
    )

    parser.add_argument(
        "--temperature",
        help="Sampling temperature for llama.",
        default=0.8,
        type=float,
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16"],
        # BUGFIX: help text previously listed only "fp32" while both fp32 and
        # fp16 are accepted choices.
        help="Override the dtype of the model (default is the checkpoint dtype). Options: fp32, fp16",
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Run the pre-generated llama in the given directory.",
        type=str,
    )

    parser.add_argument(
        "--num_sharding",
        type=int,
        default=1,
        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
    )

    parser.add_argument(
        "--model_mode",
        help="Export and inference kv mode or hybrid mode",
        default="kv",
        choices=["kv", "hybrid"],
        type=str,
    )

    parser.add_argument(
        "--max_seq_len",
        help="This refers to maximum number of tokens that the model can process & consider at once to generate predictions/responses.",
        default=512,
        type=int,
    )

    parser.add_argument(
        "--prefill_ar_len",
        help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
        default=32,
        type=int,
    )

    parser.add_argument(
        "--kv_updater",
        help="Choose how to update kv cache during runtime",
        choices=["smart_mask", "shift_pointer"],
        default="smart_mask",
        type=str,
    )

    parser.add_argument(
        "-E",
        "--embedding-quantize",
        default=None,
        type=str,
        help="Fallback to cpu embedding operator and type of embedding quantization, ',', e.g., '4,32'.",
    )

    parser.add_argument(
        "--use_tman",
        action="store_true",
        help="Use TMANLinear instead of QNNConv2d.",
    )

    parser.add_argument("-v", "--verbose", action="store_true")

    return parser
def export_llama(args) -> None:
    """Validate args, compile (unless pre-generated) and optionally run inference."""
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte as true")

    if args.model_mode == "kv":
        pte_filename = "kv_llama_qnn"
    elif args.model_mode == "hybrid":
        assert (
            args.max_seq_len >= args.prefill_ar_len
        ), "Please ensure max_seq_len is >= prefill_ar_len"
        pte_filename = "hybrid_llama_qnn"
    else:
        raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

    tokenizer = get_tokenizer(args.tokenizer_model)
    runtime_tokenizer_path = ""
    if args.llama_model == "stories110m":
        assert isinstance(
            tokenizer, SentencePieceTokenizer
        ), "Wrong tokenizer provided for stories110m."
        assert (
            args.tokenizer_bin is not None
        ), "Please provide tokenizer_bin for stories110m."
        runtime_tokenizer_path = args.tokenizer_bin
    elif args.llama_model == "llama3_2":
        assert isinstance(tokenizer, TiktokenTokenizer) or isinstance(
            tokenizer, HuggingFaceTokenizer
        ), "Wrong tokenizer provided for llama3_2."
        runtime_tokenizer_path = args.tokenizer_model
    else:
        raise RuntimeError(f"Unknown llama_model: {args.llama_model}.")

    # Resolve the kv_updater string to its callable; smart mask needs shared buffer.
    if args.kv_updater == "smart_mask":
        args.shared_buffer = True
        args.kv_updater = smart_mask_updater
    elif args.kv_updater == "shift_pointer":
        args.kv_updater = shift_pointer_updater
    else:
        # BUGFIX: "unkown" -> "unknown" in the user-facing message.
        exit(f"Using an unknown kv update {args.kv_updater}")

    if args.pre_gen_pte:
        # BUGFIX: use a context manager instead of a leaked open() handle.
        with open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt") as f:
            quant_attrs = json.load(f)
        inference(
            args, quant_attrs, pte_filename, runtime_tokenizer_path, args.pre_gen_pte
        )
        exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")

    if args.compile_only:
        quant_attrs = compile(args, pte_filename, tokenizer)
        if quant_attrs:
            with open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w") as f:
                json.dump(
                    {
                        "scale": quant_attrs["scale"],
                        "zero_point": quant_attrs["zero_point"],
                    },
                    f,
                )
        else:
            logging.warning("Quant attributes of the logit is None.")

        if args.ip and args.port != -1:
            pte_path = f"{args.artifact}/{pte_filename}.pte"
            pte_size = os.path.getsize(pte_path)
            with Client((args.ip, args.port)) as conn:
                conn.send(
                    json.dumps(
                        {
                            "pte_size": pte_size,
                        }
                    )
                )
        exit(f"Finish compile_only and save to {args.artifact}")

    try:
        quant_attrs = compile(args, pte_filename, tokenizer)
        if quant_attrs:
            logging.info(
                f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}"
            )
            with open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w") as f:
                json.dump(
                    {
                        "scale": quant_attrs["scale"],
                        "zero_point": quant_attrs["zero_point"],
                    },
                    f,
                )
        else:
            logging.warning("Quant attributes of the logit is None.")
        inference(args, quant_attrs, pte_filename, runtime_tokenizer_path)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # BUGFIX: re-raise the original exception instead of wrapping it in
            # a bare Exception, which destroyed the type and traceback.
            raise

def main():
    parser = _build_parser()
    args = parser.parse_args()
    # Wrap the raw user prompt in the llama3 chat template.
    args.prompt = "<|start_header_id|>user<|end_header_id|>\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n".format(
        args.prompt
    )
    export_llama(args)


# flake8: noqa: C901
if __name__ == "__main__":
    main()


# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Llama3 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Llama3Config(PretrainedConfig): + + model_type = "llama3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128256, + hidden_size=2560, + intermediate_size=6912, + num_hidden_layers=30, + num_attention_heads=20, + num_key_value_heads=5, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=128000, + eos_token_id=128001, + tie_word_embeddings=False, + rope_theta=500000.0, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Llama3Config"] diff --git a/examples/qualcomm/oss_scripts/llama3/model/static_llama3.py b/examples/qualcomm/oss_scripts/llama3/model/static_llama3.py new file mode 100644 index 00000000000..2feedfe6351 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama3/model/static_llama3.py @@ -0,0 +1,549 @@ +from typing import Callable, Optional, Tuple, Union, List + 
from typing import Callable, Optional, Tuple, Union, List

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.utils import logging
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import PreTrainedModel

from .configuration_llama3 import Llama3Config
from executorch.examples.models.llama.rope import precompute_freqs_cis


logger = logging.get_logger(__name__)


def apply_rotary_emb_single(
    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding to ``x`` using the half-split layout.

    The implementation of RoPE in HuggingFace processes query and key with two
    halves instead of an interleaved way. The main difference is the stride in
    the StrideSlice op; for the interleaved way the stride is two, which is not
    friendly for the HTP backend.
    Ref: https://github.com/huggingface/transformers/issues/25199
    """
    half = x.shape[-1] // 2
    x_r, x_i = x[..., :half], x[..., half:]
    # broadcast for batch_prefill mode input x
    if x.dim() == 4:
        freqs_cos = freqs_cos[None, None, :, :]
        freqs_sin = freqs_sin[None, None, :, :]
    x_out_r = x_r * freqs_cos - x_i * freqs_sin
    x_out_i = x_r * freqs_sin + x_i * freqs_cos

    return torch.cat([x_out_r, x_out_i], dim=-1)


class Llama3MLP(nn.Module):
    """SwiGLU-style feed-forward block: down(act(gate(x)) * up(x))."""

    def __init__(self, config: Llama3Config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
class Llama3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Supports three execution shapes: the original multi-head ``forward``,
    a single-head-array variant (``forward_sha``) backed by 1x1 convs, and a
    TMAN variant (``forward_tman``) that keeps the linear projections and
    splits heads with ``split_with_sizes``.
    """

    def __init__(self, config: Llama3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.dim = config.hidden_size
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.output_new_cache_only = True

        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )
        self.attn_softmax = torch.nn.Softmax(dim=-1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: List[torch.Tensor],
        v_caches: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        bsz, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(bsz, seq_len, self.n_heads, self.head_dim)
        k = self.k_proj(hidden_states).view(
            bsz, seq_len, self.n_kv_heads, self.head_dim
        )
        v = self.v_proj(hidden_states).view(
            bsz, seq_len, self.n_kv_heads, self.head_dim
        )

        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
        # k is kept transposed (bsz, n_kv_heads, head_dim, seq) for q @ k.
        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)

        output_kh, output_vh, output_y = [], [], []
        kh, vh = [], []
        # kv cache mode
        if k_caches and v_caches:
            for i, _ in enumerate(k_caches):
                kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1))
                vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1))
            for i in range(self.n_heads):
                cache_idx = i // self.num_key_value_groups

                attn = q[:, :, i, :] @ kh[cache_idx]
                attn = attn * self.scaling + atten_mask
                attn = self.attn_softmax(attn)
                y = attn @ vh[cache_idx]

                output_y.append(y)
        # batch_prefill mode
        else:
            kh = k
            vh = v
            for i in range(self.n_heads):
                cache_idx = i // self.num_key_value_groups

                attn = q[:, :, i, :] @ kh[:, cache_idx, :, :]
                attn = attn * self.scaling + atten_mask
                attn = self.attn_softmax(attn)
                y = attn @ vh[:, :, cache_idx, :]

                output_y.append(y)

        for i in range(self.n_kv_heads):
            if self.output_new_cache_only:
                # Only the newest position is emitted as updated cache.
                output_kh.append(k[:, i, :, -1])
                output_vh.append(v[:, -1, i, :])
            else:
                output_kh.append(k[:, i, :, :])
                output_vh.append(v[:, :, i, :])

        y = torch.concat(output_y, dim=-1)
        y = self.o_proj(y)

        return y, output_kh, output_vh

    def prepare_sha(self):
        """Materialize per-head 1x1 convs from the fused projections and switch
        forward to the single-head-array path."""
        self.wq_sha = nn.ModuleList(
            nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
            for _ in range(self.n_heads)
        )
        self.wk_sha = nn.ModuleList(
            nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
            for _ in range(self.n_kv_heads)
        )
        self.wv_sha = nn.ModuleList(
            nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
            for _ in range(self.n_kv_heads)
        )
        self.wo_sha = nn.Conv2d(self.n_heads * self.head_dim, self.dim, 1, bias=False)

        self.forward_mha = self.forward
        self.forward = self.forward_sha
        for i in range(self.n_heads):
            self.wq_sha[i].weight.data.copy_(
                self.q_proj.weight[
                    i * self.head_dim : (i + 1) * self.head_dim, :, None, None
                ]
            )
        for i in range(self.n_kv_heads):
            self.wk_sha[i].weight.data.copy_(
                self.k_proj.weight[
                    i * self.head_dim : (i + 1) * self.head_dim, :, None, None
                ]
            )
            self.wv_sha[i].weight.data.copy_(
                self.v_proj.weight[
                    i * self.head_dim : (i + 1) * self.head_dim, :, None, None
                ]
            )
        self.wo_sha.weight.data.copy_(self.o_proj.weight[:, :, None, None])

    def prepare_tman(self):
        """Switch forward to the TMAN path (projections stay as linears)."""
        self.forward_mha = self.forward
        self.forward = self.forward_tman

    def forward_tman(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: Optional[List[torch.Tensor]] = None,
        v_caches: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        bsz, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states).reshape(
            bsz, seq_len, self.n_heads, self.head_dim
        )
        k = self.k_proj(hidden_states).reshape(
            bsz, seq_len, self.n_kv_heads, self.head_dim
        )
        v = self.v_proj(hidden_states).reshape(
            bsz, seq_len, self.n_kv_heads, self.head_dim
        )
        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
        # Use split_with_sizes as only split_with_sizes is supported
        q = torch.split_with_sizes(q, [1 for _ in range(self.n_heads)], dim=2)
        k = torch.split_with_sizes(k, [1 for _ in range(self.n_kv_heads)], dim=1)
        v = torch.split_with_sizes(v, [1 for _ in range(self.n_kv_heads)], dim=2)
        q = [t.squeeze(2) for t in q]
        k = [t.squeeze(1) for t in k]
        v = [t.squeeze(2) for t in v]

        output_y = []
        kh, vh = [], []
        # kv cache mode
        if k_caches and v_caches:
            for i, _ in enumerate(k_caches):
                kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
                vh.append(torch.cat([v_caches[i], v[i]], dim=1))
        # batch_prefill mode
        else:
            kh = k
            vh = v

        for i, _ in enumerate(q):
            cache_idx = i // self.num_key_value_groups
            attn = q[i] @ kh[cache_idx]
            attn = attn * self.scaling + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

        y = torch.concat(output_y, dim=-1)
        y = self.o_proj(y)

        if self.output_new_cache_only:
            return y, k, v

        return y, kh, vh

    def forward_sha(
        self,
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        atten_mask: torch.Tensor,
        k_caches: Optional[List[torch.Tensor]] = None,
        v_caches: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        bsz, seq_len, _ = hidden_states.shape
        # NCHW layout so the 1x1 convs act as per-head linear projections.
        hidden_states = torch.reshape(
            hidden_states, (bsz, seq_len, 1, self.dim)
        ).transpose(1, 3)
        q = [
            wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
            for wq_sha in self.wq_sha
        ]
        k = [
            wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
            for wk_sha in self.wk_sha
        ]
        v = [
            wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
            for wv_sha in self.wv_sha
        ]
        for i in range(len(q)):
            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
        for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)

        output_y = []
        kh, vh = [], []
        # kv cache mode
        if k_caches and v_caches:
            for i, _ in enumerate(k_caches):
                kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
                vh.append(torch.cat([v_caches[i], v[i]], dim=1))
        # batch_prefill mode
        else:
            kh = k
            vh = v

        for i, _ in enumerate(q):
            cache_idx = i // self.num_key_value_groups
            attn = q[i] @ kh[cache_idx]
            attn = attn * self.scaling + atten_mask
            attn = self.attn_softmax(attn)
            y = attn @ vh[cache_idx]

            output_y.append(y)

        y = torch.concat(output_y, dim=-1)
        y = y.reshape(bsz, seq_len, 1, -1)
        y = y.transpose(1, 3)
        y = self.wo_sha(y)
        y = y.transpose(1, 3)
        y = y.reshape(bsz, seq_len, -1)

        if self.output_new_cache_only:
            return y, k, v

        return y, kh, vh
torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + h, k_cache, v_cache = self.self_attn( + hidden_states=self.input_layernorm(x), + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + h = x + h + output = h + self.mlp(self.post_attention_layernorm(h)) + return output, k_cache, v_cache + + +class Llama3Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Llama3DecoderLayer`] + + Args: + config: Llama3Config + """ + + def __init__(self, config: Llama3Config, max_seq_len: int): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Llama3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.max_seq_len = max_seq_len + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.use_scaled_rope = True + self.rope_scale_factor = config.rope_scaling["factor"] + else: + self.use_scaled_rope = False + self.rope_scale_factor = None + freqs_cos, freqs_sin = precompute_freqs_cis( + self.head_dim, + self.max_seq_len, + config.rope_theta, + self.use_scaled_rope, + self.rope_scale_factor, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + + self.use_kv_cache = True + self.n_layers = config.num_hidden_layers + 
self.n_kv_heads = config.num_key_value_heads + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + tokens: torch.Tensor, + atten_mask: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + + output_k_cache = [] + output_v_cache = [] + # following tensors should be invariant across batches + freqs_cos = ( + self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos + ) + freqs_sin = ( + self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin + ) + + hidden_states = self.embed_tokens(tokens) + for ind, decoder_layer in enumerate(self.layers): + k_caches = None + v_caches = None + if self.use_kv_cache: + offset_k = ind * self.n_kv_heads + offset_v = self.n_layers * self.n_kv_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_kv_heads] + v_caches = args[offset_v : offset_v + self.n_kv_heads] + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) + + hidden_states = self.norm(hidden_states) + + return hidden_states, output_k_cache, output_v_cache + + +class Llama3ForCausalLM(PreTrainedModel): + + def __init__( + self, + config: Llama3Config, + ar_len: int = 1, + max_seq_len: int = 128, + use_i64_token: bool = False + ): + super().__init__(config) + self.model = Llama3Model(config, max_seq_len) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.ar_len = ar_len + self.bos_id = config.bos_token_id + self.eos_id = config.eos_token_id + self.dim = config.hidden_size + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.max_batch_size = 1 + self.max_seq_len = 
max_seq_len + self.n_kv_heads = config.num_key_value_heads + self.n_layers = config.num_hidden_layers + self.use_kv_cache = True + self.use_i64_token = use_i64_token + + def forward( + self, + tokens: torch.Tensor, + atten_mask: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + + hidden_states, output_k_cache, output_v_cache = self.model(tokens, atten_mask, input_pos, *args) + logits = self.lm_head(hidden_states) + + return logits, output_k_cache, output_v_cache + + def get_example_inputs(self, use_kv_cache=True): + dtype = torch.int64 if self.use_i64_token else torch.int32 + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, self.ar_len), dtype=dtype + ) + + atten_mask = torch.full((self.ar_len, self.ar_len), torch.tensor(-255.0)) + mask_cond = torch.arange(atten_mask.size(-1)) + atten_mask.masked_fill_( + mask_cond < (mask_cond + 1).view(atten_mask.size(-1), 1), 0 + ) + if self.max_seq_len != self.ar_len: + atten_mask = torch.cat( + [ + torch.ones(self.ar_len, self.max_seq_len - self.ar_len) * -255.0, + atten_mask, + ], + dim=-1, + ) + atten_mask = atten_mask[None, :, :].expand( + self.max_batch_size, self.ar_len, self.max_seq_len + ) + if use_kv_cache: + pos_ids = torch.zeros((self.max_batch_size, self.ar_len), dtype=torch.int32) + k_cache, v_cache = [], [] + + for _ in range(self.n_layers): + for _ in range(self.n_kv_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.zeros( + self.max_batch_size, + self.head_dim, + self.max_seq_len - self.ar_len, + ) + ) + v_cache.append( + torch.zeros( + self.max_batch_size, + self.max_seq_len - self.ar_len, + self.head_dim, + ) + ) + return ( + tokens, + atten_mask, + pos_ids, + k_cache, + v_cache, + ) + + return ( + tokens, + atten_mask, + ) + + def get_metadata(self): + # TODO: modify this when enabling LLAMA 7B + return { + "get_ar_len": self.ar_len, + "get_bos_id": self.bos_id, + 
"get_eos_id": self.eos_id, + "get_dim": self.dim, + "get_head_dim": self.head_dim, + "get_max_batch_size": self.max_batch_size, + "get_max_seq_len": self.max_seq_len, + "get_n_bos": 1, + "get_n_eos": 1, + "get_n_kv_heads": self.n_kv_heads, + "get_n_layers": self.n_layers, + "get_vocab_size": self.vocab_size, + "get_use_kv_cache": self.use_kv_cache, + } + + def get_output_embeddings(self): + return self.lm_head + + def get_input_embeddings(self): + return self.model.embed_tokens diff --git a/examples/qualcomm/oss_scripts/qwen3/README.md b/examples/qualcomm/oss_scripts/qwen3/README.md new file mode 100644 index 00000000000..3cc3afceb22 --- /dev/null +++ b/examples/qualcomm/oss_scripts/qwen3/README.md @@ -0,0 +1 @@ +TODO: refactor qwen3, llama and bitnet diff --git a/examples/qualcomm/oss_scripts/qwen3/model/__init__.py b/examples/qualcomm/oss_scripts/qwen3/model/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/qualcomm/oss_scripts/qwen3/model/configuration_qwen3.py b/examples/qualcomm/oss_scripts/qwen3/model/configuration_qwen3.py new file mode 100644 index 00000000000..5d70c37edfb --- /dev/null +++ b/examples/qualcomm/oss_scripts/qwen3/model/configuration_qwen3.py @@ -0,0 +1,80 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Qwen3 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class Qwen3Config(PretrainedConfig): + + model_type = "qwen3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128256, + hidden_size=2560, + intermediate_size=6912, + num_hidden_layers=30, + num_attention_heads=20, + num_key_value_heads=5, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=128000, + eos_token_id=128001, + tie_word_embeddings=False, + rope_theta=500000.0, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Qwen3Config"] diff --git a/examples/qualcomm/oss_scripts/qwen3/model/static_qwen3.py b/examples/qualcomm/oss_scripts/qwen3/model/static_qwen3.py new file mode 100644 index 00000000000..63ab3d9b045 --- /dev/null +++ b/examples/qualcomm/oss_scripts/qwen3/model/static_qwen3.py @@ -0,0 +1,555 @@ +from typing import Callable, Optional, Tuple, Union, List + +import 
torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.utils import logging +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel + +from .configuration_qwen3 import Qwen3Config +from executorch.examples.models.llama.rope import precompute_freqs_cis + + +logger = logging.get_logger(__name__) + + +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + # The implementation of RoPE in HuggingFace processes query and key with two half instead of interleaved way. + # The main difference is stride in StrideSlice op. For interleaved way, stride is two which is not friendly for HTP backend. + # Ref: https://github.com/huggingface/transformers/issues/25199 + x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + # broadcast for batch_prefill mode input x + if x.dim() == 4: + freqs_cos = freqs_cos[None, None, :, :] + freqs_sin = freqs_sin[None, None, :, :] + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + +class Qwen3MLP(nn.Module): + def __init__(self, config: Qwen3Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + 
super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.dim = config.hidden_size + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + self.n_heads = config.num_attention_heads + self.n_kv_heads = config.num_key_value_heads + self.output_new_cache_only = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.attn_softmax = torch.nn.Softmax(dim=-1) + self.q_norm = torch.nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
+ self.k_norm = torch.nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seq_len, _ = hidden_states.shape + + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q = q.view(bsz, seq_len, self.n_heads, self.head_dim) + k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + q = self.q_norm(q) + k = self.k_norm(k) + + q = apply_rotary_emb_single(q, freqs_cos, freqs_sin) + k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1) + + output_kh, output_vh, output_y = [], [], [] + kh, vh = [], [] + # kv cache mode + if k_caches and v_caches: + for i, _ in enumerate(k_caches): + kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1)) + vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1)) + for i in range(self.n_heads): + cache_idx = i // self.num_key_value_groups + + attn = q[:, :, i, :] @ kh[cache_idx] + attn = attn * self.scaling + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh[cache_idx] + + output_y.append(y) + + # batch_prefill mode + else: + kh = k + vh = v + for i in range(self.n_heads): + cache_idx = i // self.num_key_value_groups + + attn = q[:, :, i, :] @ kh[:, cache_idx, :, :] + attn = attn * self.scaling + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh[:, :, cache_idx, :] + + output_y.append(y) + + for i in range(self.n_kv_heads): + if self.output_new_cache_only: + output_kh.append(k[:, i, :, -1]) + output_vh.append(v[:, -1, i, :]) + else: + output_kh.append(k[:, i, :, :]) + output_vh.append(v[:, :, i, :]) + + y = torch.concat(output_y, dim=-1) + y = self.o_proj(y) + + return y, output_kh, output_vh + + def 
prepare_sha(self): + self.wq_sha = nn.ModuleList( + [ + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk_sha = nn.ModuleList( + [ + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + self.wv_sha = nn.ModuleList( + [ + nn.Conv2d(self.dim, self.head_dim, 1, bias=False) + for _ in range(self.n_kv_heads) + ] + ) + self.wo_sha = nn.Conv2d(self.n_heads * self.head_dim, self.dim, 1, bias=False) + + self.forward_mha = self.forward + self.forward = self.forward_sha + for i in range(self.n_heads): + self.wq_sha[i].weight.data.copy_( + self.q_proj.weight[ + i * self.head_dim : (i + 1) * self.head_dim, :, None, None + ] + ) + for i in range(self.n_kv_heads): + self.wk_sha[i].weight.data.copy_( + self.k_proj.weight[ + i * self.head_dim : (i + 1) * self.head_dim, :, None, None + ] + ) + self.wv_sha[i].weight.data.copy_( + self.v_proj.weight[ + i * self.head_dim : (i + 1) * self.head_dim, :, None, None + ] + ) + self.wo_sha.weight.data.copy_(self.o_proj.weight[:, :, None, None]) + + def prepare_tman(self): + self.forward_mha = self.forward + self.forward = self.forward_tman + + def forward_tman( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: Optional[List[torch.Tensor]] = None, + v_caches: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seq_len, _ = hidden_states.shape + + q = self.q_proj(hidden_states).reshape(bsz, seq_len, self.n_heads, self.head_dim) + k = self.k_proj(hidden_states).reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) + v = self.v_proj(hidden_states).reshape(bsz, seq_len, self.n_kv_heads, self.head_dim) + q = self.q_norm(q) + k = self.k_norm(k) + q = apply_rotary_emb_single(q, freqs_cos, freqs_sin) + k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1) + # Use split_with_sizes as only split_with_sizes is 
supported + q = torch.split_with_sizes(q, [1 for _ in range(self.n_heads)], dim=2) + k = torch.split_with_sizes(k, [1 for _ in range(self.n_kv_heads)], dim=1) + v = torch.split_with_sizes(v, [1 for _ in range(self.n_kv_heads)], dim=2) + q = [t.squeeze(2) for t in q] + k = [t.squeeze(1) for t in k] + v = [t.squeeze(2) for t in v] + + output_y = [] + kh, vh = [], [] + # kv cache mode + if k_caches and v_caches: + for i, _ in enumerate(k_caches): + kh.append(torch.cat([k_caches[i], k[i]], dim=-1)) + vh.append(torch.cat([v_caches[i], v[i]], dim=1)) + # batch_prefill mode + else: + kh = k + vh = v + + for i, _ in enumerate(q): + cache_idx = i // self.num_key_value_groups + attn = q[i] @ kh[cache_idx] + attn = attn * self.scaling + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh[cache_idx] + + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.o_proj(y) + + if self.output_new_cache_only: + return y, k, v + + return y, kh, vh + + def forward_sha( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: Optional[List[torch.Tensor]] = None, + v_caches: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seq_len, _ = hidden_states.shape + hidden_states = torch.reshape( + hidden_states, (bsz, seq_len, 1, self.dim) + ).transpose(1, 3) + q = [ + wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + for wq_sha in self.wq_sha + ] + k = [ + wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + for wk_sha in self.wk_sha + ] + v = [ + wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + for wv_sha in self.wv_sha + ] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + for i in range(len(k)): + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1) + + output_y = [] + kh, vh = [], [] + # kv cache mode + if 
k_caches and v_caches: + for i, _ in enumerate(k_caches): + kh.append(torch.cat([k_caches[i], k[i]], dim=-1)) + vh.append(torch.cat([v_caches[i], v[i]], dim=1)) + # batch_prefill mode + else: + kh = k + vh = v + + for i, _ in enumerate(q): + cache_idx = i // self.num_key_value_groups + attn = q[i] @ kh[cache_idx] + attn = attn * self.scaling + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh[cache_idx] + + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = y.reshape(bsz, seq_len, 1, -1) + y = y.transpose(1, 3) + y = self.wo_sha(y) + y = y.transpose(1, 3) + y = y.reshape(bsz, seq_len, -1) + + if self.output_new_cache_only: + return y, k, v + + return y, kh, vh + + +class Qwen3DecoderLayer(nn.Module): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3MLP(config) + self.input_layernorm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + h, k_cache, v_cache = self.self_attn( + hidden_states=self.input_layernorm(x), + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + h = x + h + output = h + self.mlp(self.post_attention_layernorm(h)) + return output, k_cache, v_cache + + +class Qwen3Model(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Qwen3DecoderLayer`] + + Args: + config: Qwen3Config + """ + + def __init__(self, config: Qwen3Config, max_seq_len: int): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = torch.nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.max_seq_len = max_seq_len + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.use_scaled_rope = True + self.rope_scale_factor = config.rope_scaling["factor"] + else: + self.use_scaled_rope = False + self.rope_scale_factor = None + freqs_cos, freqs_sin = precompute_freqs_cis( + self.head_dim, + self.max_seq_len, + config.rope_theta, + self.use_scaled_rope, + self.rope_scale_factor, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + + self.use_kv_cache = True + self.n_layers = config.num_hidden_layers + self.n_kv_heads = config.num_key_value_heads + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + tokens: torch.Tensor, + atten_mask: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + + output_k_cache = [] + output_v_cache = [] + # following tensors should be invariant across batches + freqs_cos = ( + self.freqs_cos[input_pos][0] if self.use_kv_cache else self.freqs_cos + ) + freqs_sin = ( + self.freqs_sin[input_pos][0] if self.use_kv_cache else self.freqs_sin + ) + + hidden_states = self.embed_tokens(tokens) + for ind, decoder_layer in 
enumerate(self.layers): + k_caches = None + v_caches = None + if self.use_kv_cache: + offset_k = ind * self.n_kv_heads + offset_v = self.n_layers * self.n_kv_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_kv_heads] + v_caches = args[offset_v : offset_v + self.n_kv_heads] + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) + + hidden_states = self.norm(hidden_states) + + return hidden_states, output_k_cache, output_v_cache + + +class Qwen3ForCausalLM(PreTrainedModel): + + def __init__( + self, + config: Qwen3Config, + ar_len: int = 1, + max_seq_len: int = 128, + use_i64_token: bool = False + ): + super().__init__(config) + self.model = Qwen3Model(config, max_seq_len) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.ar_len = ar_len + self.bos_id = config.bos_token_id + self.eos_id = config.eos_token_id + self.dim = config.hidden_size + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.max_batch_size = 1 + self.max_seq_len = max_seq_len + self.n_kv_heads = config.num_key_value_heads + self.n_layers = config.num_hidden_layers + self.use_kv_cache = True + self.use_i64_token = use_i64_token + + def forward( + self, + tokens: torch.Tensor, + atten_mask: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + + hidden_states, output_k_cache, output_v_cache = self.model(tokens, atten_mask, input_pos, *args) + logits = self.lm_head(hidden_states) + + return logits, output_k_cache, output_v_cache + + def get_example_inputs(self, use_kv_cache=True): + dtype = torch.int64 if self.use_i64_token else torch.int32 + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 
self.ar_len), dtype=dtype + ) + + atten_mask = torch.full((self.ar_len, self.ar_len), torch.tensor(-255.0)) + mask_cond = torch.arange(atten_mask.size(-1)) + atten_mask.masked_fill_( + mask_cond < (mask_cond + 1).view(atten_mask.size(-1), 1), 0 + ) + if self.max_seq_len != self.ar_len: + atten_mask = torch.cat( + [ + torch.ones(self.ar_len, self.max_seq_len - self.ar_len) * -255.0, + atten_mask, + ], + dim=-1, + ) + atten_mask = atten_mask[None, :, :].expand( + self.max_batch_size, self.ar_len, self.max_seq_len + ) + if use_kv_cache: + pos_ids = torch.zeros((self.max_batch_size, self.ar_len), dtype=torch.int32) + k_cache, v_cache = [], [] + + for _ in range(self.n_layers): + for _ in range(self.n_kv_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.zeros( + self.max_batch_size, + self.head_dim, + self.max_seq_len - self.ar_len, + ) + ) + v_cache.append( + torch.zeros( + self.max_batch_size, + self.max_seq_len - self.ar_len, + self.head_dim, + ) + ) + return ( + tokens, + atten_mask, + pos_ids, + k_cache, + v_cache, + ) + + return ( + tokens, + atten_mask, + ) + + def get_metadata(self): + # TODO: modify this when enabling LLAMA 7B + return { + "get_ar_len": self.ar_len, + "get_bos_id": self.bos_id, + "get_eos_id": self.eos_id, + "get_dim": self.dim, + "get_head_dim": self.head_dim, + "get_max_batch_size": self.max_batch_size, + "get_max_seq_len": self.max_seq_len, + "get_n_bos": 1, + "get_n_eos": 1, + "get_n_kv_heads": self.n_kv_heads, + "get_n_layers": self.n_layers, + "get_vocab_size": self.vocab_size, + "get_use_kv_cache": self.use_kv_cache, + } + + def get_output_embeddings(self): + return self.lm_head + + def get_input_embeddings(self): + return self.model.embed_tokens diff --git a/examples/qualcomm/oss_scripts/qwen3/qwen3.py b/examples/qualcomm/oss_scripts/qwen3/qwen3.py new file mode 100644 index 00000000000..f038bc9abde --- /dev/null +++ b/examples/qualcomm/oss_scripts/qwen3/qwen3.py @@ -0,0 +1,1201 @@ +# Copyright (c) 
Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# TODO: reenable pyre after fixing the issues +# pyre-ignore-all-errors + +import copy +import getpass +import json +import logging +import os +import subprocess +import sys +import time +from functools import partial +from multiprocessing.connection import Client + +import torch +from executorch.backends.qualcomm._passes import FoldQDQ, TagQuantIO +from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32 +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm._passes.utils import ( + get_passes_dependency_for_capture_program, +) + +from executorch.backends.qualcomm.builders.utils import is_graph_output + +from executorch.backends.qualcomm.quantizer.custom_annotation import ( + annotate_linear_16a8w_in_affine_layer, + annotate_matmul_16a8w, + annotate_prefill_kv_output, +) + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset + +from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( + flatbuffer_to_option, + option_to_flatbuffer, +) +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANT_ATTRS_MAP, +) +from executorch.backends.qualcomm.utils.utils import ( + convert_linear_to_conv2d, + convert_linear_to_qlinear, + convert_qlinear_to_tman_linear, + convert_qlinear_to_linear, + generate_composite_llama_program, + generate_htp_compiler_spec, + generate_multi_graph_program, + generate_qnn_executorch_compiler_spec, + get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, + update_spill_fill_size, +) + +from executorch.devtools.backend_debug import print_delegation_info +from 
executorch.examples.models.llama.source_transformation.quantize import ( + get_quant_embedding_transform, +) +from executorch.examples.qualcomm.oss_scripts.qwen3.model.static_qwen3 import ( + Qwen3ForCausalLM, + Qwen3Config, + Qwen3DecoderLayer, +) +from executorch.examples.qualcomm.utils import ( + make_output_dir, + make_quantizer, + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from executorch.extension.llm.custom_ops import model_sharding +from executorch.extension.llm.export.builder import DType +from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer, HuggingFaceTokenizer +from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer + +from torch.ao.quantization.observer import MinMaxObserver +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + +from gptqmodel import GPTQModel +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + +sys.setrecursionlimit(4096) +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + + +def smart_mask_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +): + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. 
+ if pos >= ar_len: + for i, k_cache in enumerate(k_caches): + k_cache[:, :, pos - ar_len] = new_k_caches[i][:, :, 0] + + for i, v_cache in enumerate(v_caches): + v_cache[:, pos - ar_len, :] = new_v_caches[i][:, 0, :] + atten_mask[:, :, pos - ar_len] = 0 + + pos += 1 + return (atten_mask, pos, k_caches, v_caches) + + +def shift_pointer_updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches +): + # Update the KV cache input for the next inference when the position exceeds the autoregressive length. + if pos >= ar_len: + k_caches = [ + torch.cat([k_cache[:, :, 1:], new_k_caches[i][:, :, :1]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i][:, :1, :]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + atten_mask[:, :, -pos - 1] = 0 + + pos += 1 + return (atten_mask, pos, k_caches, v_caches) + + +def _kv_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + ar_len=1, + max_seq_len=512, + updater=smart_mask_updater, + use_i64_token=False, +): + _, atten_mask, _, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + all_pos = torch.arange(0, max_seq_len, 1, dtype=torch.int32).unsqueeze(0) + + token_list = [] + # Llama2 tokenizer has no special tokens + if isinstance(tokenizer, SentencePieceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + elif isinstance(tokenizer, TiktokenTokenizer): + token_list = tokenizer.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) + elif isinstance(tokenizer, HuggingFaceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + else: + raise RuntimeError("Unknown tokenizer") + + pos = len(token_list) if len(token_list) < ar_len else ar_len + dtype = torch.int64 if use_i64_token else torch.int32 + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: + 
tmp_token_list = torch.tensor( + token_list[pos - ar_len : pos], dtype=dtype + ).reshape(1, -1) + tmp_pos = all_pos[:, pos - ar_len : pos] + tmp_atten_mask = atten_mask + if pos < ar_len: + tmp_token_list = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=dtype), + torch.tensor(token_list, dtype=dtype).reshape(1, -1), + ], + dim=1, + ) + tmp_pos = torch.cat( + [ + torch.zeros((1, ar_len - pos), dtype=torch.int32), + all_pos[:, :pos], + ], + dim=1, + ) + tmp_atten_mask = torch.cat( + [ + torch.ones(1, ar_len, max_seq_len - pos) * -255.0, + atten_mask[:, :, -pos:], + ], + dim=-1, + ) + + logits, new_k_caches, new_v_caches = module( + tmp_token_list, + tmp_atten_mask, + tmp_pos, + *k_caches, + *v_caches, + ) + atten_mask, pos, k_caches, v_caches = updater( + ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches + ) + if pos > len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + + print(f"kv calibration data:\n{tokenizer.decode(token_list)}") + + +def _prefill_calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + max_seq_len=512, + use_i64_token=False, +): + _, atten_mask = example_inputs + + # TODO: change criteria & support batch inputs if necessary + + token_list = [] + # Llama2 tokenizer has no special tokens + if isinstance(tokenizer, SentencePieceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + elif isinstance(tokenizer, TiktokenTokenizer): + token_list = tokenizer.encode( + user_prompts, bos=True, eos=False, allowed_special="all" + ) + elif isinstance(tokenizer, HuggingFaceTokenizer): + token_list = tokenizer.encode(user_prompts, bos=True, eos=False) + else: + raise RuntimeError("Unknown tokenizer") + + pos = len(token_list) + dtype = torch.int64 if use_i64_token else torch.int32 + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_seq_len: + tmp_token_list = torch.tensor(token_list, dtype=dtype).reshape(1, -1) + 
if pos < max_seq_len: + tmp_token_list = torch.cat( + [ + tmp_token_list, + torch.zeros((1, max_seq_len - pos), dtype=dtype), + ], + dim=1, + ) + results = module( + tmp_token_list, + atten_mask, + ) + if len(results) == 3: + logits, new_k_caches, new_v_caches = results + elif len(results) == 1: + logits = results + token_list.append(torch.argmax(logits[:, pos - 1], dim=-1).item()) + pos += 1 + + print(f"prefill calibration data:\n{tokenizer.decode(token_list)}") + + +def calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer, + ar_len=1, + max_seq_len=512, + kv_updater=smart_mask_updater, + use_i64_token=False, +): + if len(example_inputs) == 2: + _prefill_calibrate( + example_inputs, + user_prompts, + module, + tokenizer, + max_seq_len, + use_i64_token, + ) + elif len(example_inputs) == 5: + _kv_calibrate( + example_inputs, + user_prompts, + module, + tokenizer, + ar_len, + max_seq_len, + updater=kv_updater, + use_i64_token=use_i64_token, + ) + else: + raise RuntimeError("Get wrong inputs") + + +class SingleLlama: + def __init__(self, llama_model, pte_filename) -> None: + super().__init__() + self.llama_model = llama_model + self.passes_job = get_capture_program_passes() + self.dep_table = get_passes_dependency_for_capture_program() + self.quant_attrs = None + self.quant_dtype = None + self.llama_meta = self.llama_model.get_metadata() + self.has_quant_io = False + self.pte_filename = pte_filename + if self.llama_meta["get_use_kv_cache"]: + tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs( + use_kv_cache=True + ) + self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches) + else: + tokens, atten_mask = self.get_example_inputs(use_kv_cache=False) + self.inputs = (tokens, atten_mask) + self.llama_graph_module = llama_model + self.io_shape = { + # logit output + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_vocab_size"], + ), + } + + def 
_tag_ios(self, node, fixed_point_type): + if not self.has_quant_io: + return + + # shape of k caches and v caches + kv_cache_shape = { + # single head, kv input + (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), + (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), + # single head, kv output + (self.llama_meta["get_head_dim"], self.llama_meta["get_ar_len"]), + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"]), + } + + atten_mask_shape = { + ( + self.llama_meta["get_max_batch_size"], + self.llama_meta["get_ar_len"], + self.llama_meta["get_max_seq_len"], + ), + } + + freq_shape = { + (self.llama_meta["get_ar_len"], self.llama_meta["get_head_dim"] // 2), + } + + freq_op = { + exir_ops.edge.aten.select.int, + } + quant_io_type = None + + if node.op == "placeholder": + if ( + len(users := list(node.users)) == 1 + and users[0].meta["val"].size()[-2:] in kv_cache_shape + ): + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + elif node.meta["val"].size() in atten_mask_shape: + quant_io_type = fixed_point_type["io_type"] + if is_graph_output(node): + if node.meta["val"].size()[-2:] in kv_cache_shape: + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in self.io_shape: + quant_io_type = fixed_point_type["io_type"] + + # Tag sharding io + if exir_ops.edge.llama.fallback.default in [ + u.target for u in list(node.users.keys()) + ] + [node.target]: + quant_io_type = fixed_point_type["io_type"] + + # Tag select op as quantized tensors for freq_sin and freq_cos. 
It is caused by sharding + if node.target in freq_op and node.meta["val"].size() in freq_shape: + quant_io_type = fixed_point_type["io_type"] + + return quant_io_type + + def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): + self.quant_dtype = quant_dtype + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) + quantizer.add_custom_quant_annotations(custom_annotations) + + self.has_quant_io = True + fx_graph_module = None + + with torch.no_grad(): + fx_graph_module = torch.export.export( + self.llama_graph_module, self.inputs, strict=True + ).module() + + if quant_dtype == QuantDtype.use_16a4w_block: + conv_nodes = [ + n for n in fx_graph_module.graph.nodes if "conv" in n.name + ] + block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes} + quantizer.set_block_size_map(block_size_map) + + fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) + + logging.info("Quantizing the model...") + calibrate( + self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), + args.prompt, + fx_graph_module, + tokenizer=tokenizer, + ar_len=self.llama_meta["get_ar_len"], + max_seq_len=self.llama_meta["get_max_seq_len"], + kv_updater=args.kv_updater, + use_i64_token=args.embedding_quantize is not None, + ) + + self.llama_graph_module = convert_pt2e(fx_graph_module) + + def lowering_modules( + self, + work_space, + use_fp16=False, + soc_model=QcomChipset.SM8650, + num_sharding=1, + shared_buffer=False, + verbose=False, + ): + executorch_config = ExecutorchBackendConfig( + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + with torch.no_grad(): + # backend option + backend_options = generate_htp_compiler_spec( + use_fp16=use_fp16, use_multi_contexts=num_sharding > 1 + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=soc_model, + backend_options=backend_options, + shared_buffer=shared_buffer, + ) + skip_node_op_set = {"llama.fallback.default"} + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + self.llama_graph_module, + self.inputs, + compiler_specs, + constant_methods=self.llama_meta, + dep_table=self.dep_table, + passes_job=self.passes_job, + skip_node_op_set=skip_node_op_set, + ) + + for n in edge_prog_mgr.exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in self.io_shape: + self.quant_attrs = output_encoding + + if num_sharding > 1: + update_spill_fill_size(edge_prog_mgr.exported_program()) + + if verbose: + print_delegation_info(edge_prog_mgr.exported_program().graph_module) + + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + def get_example_inputs(self, use_kv_cache=True): + return self.llama_model.get_example_inputs(use_kv_cache) + + def get_quant_attrs(self): + return self.quant_attrs + + +def compile(args, pte_filename, tokenizer): + os.makedirs(args.artifact, exist_ok=True) + start_ts = time.time() + + config = Qwen3Config.from_pretrained(args.model_dir) + + llama_instance_list = [] + use_i64_token = args.embedding_quantize is not None + with torch.device("meta"): + if args.model_mode == "kv": + llama_instance_list.append( + Qwen3ForCausalLM( + config, + ar_len=1, + max_seq_len=args.max_seq_len, + use_i64_token=use_i64_token, + ) + ) + elif args.model_mode == "hybrid": + 
llama_instance_list.append( + Qwen3ForCausalLM( + config, + ar_len=1, + max_seq_len=args.max_seq_len, + use_i64_token=use_i64_token, + ) + ) + llama_instance_list.append( + Qwen3ForCausalLM( + config, + ar_len=args.prefill_ar_len, + max_seq_len=args.max_seq_len, + use_i64_token=use_i64_token, + ) + ) + else: + raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + qmodel = GPTQModel.from_quantized(args.model_dir) + state_dict = qmodel.model.state_dict() + qcfg = qmodel.quantize_config + if qcfg.desc_act: + raise RuntimeError( + "desc_act=True is unsupported right now." + ) + qlinear_cls = partial( + TorchQuantLinear, + bits=qcfg.bits, + group_size=qcfg.group_size, + desc_act=qcfg.desc_act, + sym=qcfg.sym, + pack_dtype=qcfg.pack_dtype, + device=qcfg.device, + adapter=qcfg.adapter, + ) + for llama_instance in llama_instance_list: + for layer in llama_instance.model.layers: + layer: Qwen3DecoderLayer + convert_linear_to_qlinear(layer.self_attn, qlinear_cls) + convert_linear_to_qlinear(layer.mlp, qlinear_cls) + + for llama_instance in llama_instance_list: + incompatible_keys = llama_instance.load_state_dict( + state_dict, + strict=False, + assign=True, + ) + assert len(incompatible_keys.missing_keys) <= 1 and len(incompatible_keys.unexpected_keys) == 0 + if "lm_head.weight" in incompatible_keys.missing_keys: + llama_instance.tie_weights() + end_load_ts = time.time() + logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}") + + if args.dtype_override is not None: + dtype_override = DType[args.dtype_override] + for i in range(len(llama_instance_list)): + llama_instance_list[i] = llama_instance_list[i].to( + dtype_override.to_torch_dtype() + ) + + for llama_instance in llama_instance_list: + for layer in llama_instance.model.layers: + if args.use_tman: + convert_qlinear_to_tman_linear(layer.self_attn) + layer.self_attn.prepare_tman() + convert_qlinear_to_tman_linear(layer.mlp) + else: + convert_qlinear_to_linear(layer.self_attn) + 
layer.self_attn.prepare_sha() + convert_qlinear_to_linear(layer.mlp) + + use_fp16 = True + fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} + if args.ptq: + use_fp16 = False + fixed_point_type["kv_type"] = torch.uint8 + if args.ptq == "8a8w": + fixed_point_type["io_type"] = torch.uint8 + elif args.ptq in ("16a4w", "16a4w_block"): + fixed_point_type["io_type"] = torch.uint16 + else: + assert args.ptq in [ + "8a8w", + "16a4w", + "16a4w_block", + ], f"No support for quant type {args.ptq}. Support 8a8w, 16a4w and 16a4w_block." + quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") + + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + + for i in range(len(llama_instance_list)): + if args.embedding_quantize: + llama_instance_list[i] = get_quant_embedding_transform(args)( + llama_instance_list[i] + ) + llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i]) + llama_instance_list[i] = SingleLlama( + llama_instance_list[i].eval(), pte_filename + ) + if args.embedding_quantize: + llama_instance_list[i].passes_job[I64toI32][ + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY + ]["skip_node"] = {"tokens"} + + if args.ptq: + start_quantize_ts = time.time() + custom_annotations = (annotate_matmul_16a8w,) + if args.llama_model == "stories110m": + custom_annotations = custom_annotations + ( + annotate_linear_16a8w_in_affine_layer, + ) + kv_quant_attrs = {} + for i, llama_instance in enumerate(llama_instance_list): + llama_instance.quantize( + quant_dtype=quant_dtype, + args=args, + tokenizer=tokenizer, + custom_annotations=custom_annotations, + ) + # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later + if i == 0 and args.model_mode == "hybrid": + output_indices = 0 + for node in llama_instance.llama_graph_module.graph.nodes: + if node.op == "output": + for output in node.args[0]: + kv_quant_attrs[output_indices] = output.args[1:] + output_indices += 1 + break + custom_annotations = 
custom_annotations + ( + partial( + annotate_prefill_kv_output, + kv_quant_attrs=kv_quant_attrs, + ), + ) + llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = partial(llama_instance._tag_ios, fixed_point_type=fixed_point_type) + end_quantize_ts = time.time() + logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") + + start_lowering_ts = time.time() + quant_attrs = None + if args.num_sharding > 1: + for llama_instance in llama_instance_list: + SplitGraph, setting = model_sharding.get_split_graph_pass( + llama_instance.llama_meta["get_n_layers"], + shares=args.num_sharding, + ) + llama_instance.passes_job[SplitGraph] = setting + llama_instance.dep_table[SplitGraph] = [FoldQDQ] + llama_instance.dep_table[TagQuantIO] = [SplitGraph] + + if args.model_mode in ["kv"]: + llama_instance_list[0].lowering_modules( + args.artifact, + use_fp16=use_fp16, + soc_model=get_soc_to_chipset_map()[args.model], + num_sharding=args.num_sharding, + shared_buffer=args.shared_buffer, + ) + quant_attrs = llama_instance_list[0].get_quant_attrs() + elif args.model_mode == "hybrid": + sample_inputs_list = [ + llama_instace.inputs for llama_instace in llama_instance_list + ] + backend_options = generate_htp_compiler_spec( + use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 1 + ) + graph_names = ["kv_forward", "prefill_forward"] + compiler_specs = [ + generate_qnn_executorch_compiler_spec( + soc_model=get_soc_to_chipset_map()[args.model], + backend_options=backend_options, + shared_buffer=args.shared_buffer, + multiple_graphs=True, + weight_sharing=not args.enable_x86_64, # x86 emulator does not support weight sharing + graph_name=graph_name, + ) + for graph_name in graph_names + ] + skip_node_op_set = {"llama.fallback.default"} + edge_prog_mgrs = [ + to_edge_transform_and_lower_to_qnn( + llama_instance.llama_graph_module, + sample_input, + 
compile_spec, + dep_table=llama_instance.dep_table, + passes_job=llama_instance.passes_job, + skip_node_op_set=skip_node_op_set, + ) + for llama_instance, sample_input, compile_spec in zip( + llama_instance_list, sample_inputs_list, compiler_specs + ) + ] + for n in edge_prog_mgrs[0].exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in llama_instance_list[0].io_shape: + quant_attrs = output_encoding + + if args.num_sharding > 1: + max_sf_size = update_spill_fill_size( + [edge_prog_mgr.exported_program() for edge_prog_mgr in edge_prog_mgrs] + ) + qnn_executorch_options = flatbuffer_to_option(compiler_specs[0][0].value) + qnn_executorch_options.backend_options.htp_options.max_sf_buf_size = ( + max_sf_size + ) + compiler_specs[0][0].value = option_to_flatbuffer(qnn_executorch_options) + + if args.verbose: + for edge_prog_mgr in edge_prog_mgrs: + print_delegation_info(edge_prog_mgr.exported_program().graph_module) + + executorch_config = ExecutorchBackendConfig( + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + + bundle_progs_list = [] + lower_module_dict = {name: [] for name in graph_names} + call_delegate_inputs_dict = {name: [] for name in graph_names} + call_delegate_node_name_dict = {name: [] for name in graph_names} + outputs_dict = {name: [] for name in graph_names} + input_nodes_dict = {name: [] for name in graph_names} + for prog, graph_name in zip(edge_prog_mgrs, graph_names): + for node in prog.exported_program().graph_module.graph.nodes: + if ( + node.op == "call_function" + and "executorch_call_delegate" in node.name + ): + call_delegate_node_name_dict[graph_name].append(node.name) + call_delegate_inputs_list = [] + for arg in node.args: + if arg.op == "call_function": + if ( + arg.target + == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype + ): + call_delegate_inputs_list.append((arg.name, None)) + else: + while "getitem" not in arg.name: + arg = arg.args[0] + call_delegate_inputs_list.append( + (arg.args[0].name, arg.args[1]) + ) + elif arg.op == "placeholder": + call_delegate_inputs_list.append((arg.name, None)) + # No extra needs to do for get_attr node + call_delegate_inputs_dict[graph_name].append( + call_delegate_inputs_list + ) + elif node.op == "output": + for arg in node.args[0]: + outputs_dict[graph_name].append((arg.args[0].name, arg.args[1])) + for num in range(args.num_sharding - 1, -1, -1): + processed_bytes = [] + for prog, graph_name in zip(edge_prog_mgrs, graph_names): + processed_bytes.append( + getattr( + prog.exported_program().graph_module, f"lowered_module_{num}" + ).processed_bytes + ) + call_delegate_node = [ + list(node.users.keys())[0] + for node in prog.exported_program().graph_module.graph.nodes + if node.op == "get_attr" and node.name == f"lowered_module_{num}" + ] + input_nodes_dict[graph_name] = [ + node + for node in call_delegate_node[0].args + if node.op == "placeholder" + or 
node.target + == exir_ops.edge.quantized_decomposed.embedding_4bit.dtype + ] + prog_mgr, bundle_progs = generate_multi_graph_program( + compiler_specs=compiler_specs[0], + processed_bytes=processed_bytes, + input_nodes_dict=input_nodes_dict, + backend_config=executorch_config, + constant_methods=llama_instance_list[0].llama_meta, # kv method meta + ) + bundle_progs_list.append(bundle_progs) + for graph_name in graph_names: + lower_module_dict[graph_name].append( + prog_mgr.exported_program(graph_name).graph_module._modules.get( + "lowered_module_0" + ) + ) + exec_prog = generate_composite_llama_program( + llama_model=llama_instance_list[1].llama_model, + graph_names=graph_names, + sample_inputs_list=sample_inputs_list, + lower_module_dict=lower_module_dict, + call_delegate_node_name_dict=call_delegate_node_name_dict, + call_delegate_inputs_dict=call_delegate_inputs_dict, + outputs_dict=outputs_dict, + embedding_quantize=args.embedding_quantize, + backend_config=executorch_config, + constant_methods=llama_instance_list[1].llama_meta, # kv method meta + ) + with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: + exec_prog.write_to_file(file) + + end_lowering_ts = time.time() + logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}") + return quant_attrs + + +def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte=""): + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" + + if args.model_mode == "kv": + eval_mode = 0 + elif args.model_mode == "hybrid": + eval_mode = 1 + else: + raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + pte_path = ( + f"{pre_gen_pte}/{pte_filename}.pte" + if pre_gen_pte + else f"{args.artifact}/{pte_filename}.pte" + ) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + with open(f"{args.artifact}/outputs/outputs.txt", "r") as f: + 
outputs.append(f.read()) + + seq_len = args.max_seq_len + runner_args = " ".join( + [ + f'--prompt "{args.prompt}"', + f"--eval_mode {eval_mode}", + f"--temperature {args.temperature}", + f"--system_prompt '{args.system_prompt}'", + f"--logits_scale {quant_attrs['scale']}", + f"--logits_offset {quant_attrs['zero_point']}", + ] + ) + + runner_cmd = "" + performance_output_path = "outputs/inference_speed.txt" + if args.enable_x86_64: + # x86 emulator is intended for CI and not performance. Check only the first few tokens. + seq_len = min(seq_len, 16) + + if args.kv_updater == smart_mask_updater: + logging.warning( + "x86 only support ShiftPointer, overwrite kv_updater to ShiftPointer" + ) + + qnn_sdk = os.getenv("QNN_SDK_ROOT") + target = "x86_64-linux-clang" + runner_cmd = " ".join( + [ + f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&", + f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/qnn_llama_runner", + f"--tokenizer_path {runtime_tokenizer_path}", + f"--model_path {pte_path}", + f"--seq_len {seq_len}", + f"--output_path {args.artifact}/outputs/outputs.txt", + f"--performance_output_path {performance_output_path}", + f"--kv_updater ShiftPointer", + runner_args, + ] + ) + subprocess.run( + runner_cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + ) + post_process() + else: + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + f"./qnn_llama_runner", + f"--tokenizer_path {os.path.basename(runtime_tokenizer_path)}", + f"--model_path {pte_filename}.pte", + f"--seq_len {seq_len}", + "--output_path outputs/outputs.txt", + f"--performance_output_path {performance_output_path}", + f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}", + runner_args, + ] + ) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + 
shared_buffer=args.shared_buffer, + runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner", + ) + # No pregen inputs, input_list is not required + adb.push(inputs=[], input_list="", files=[runtime_tokenizer_path]) + adb.execute(custom_runner_cmd=runner_cmd) + + adb.pull(output_path=args.artifact, callback=post_process) + if args.ip and args.port != -1: + inference_speed = 0 + with open(f"{args.artifact}/{performance_output_path}", "r") as f: + inference_speed = float(f.read()) + + pte_size = os.path.getsize(pte_path) + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps( + { + "result": outputs, + "pte_size": pte_size, + "inference_speed": inference_speed, + } + ) + ) + else: + for idx, output in enumerate(outputs): + logging.info(f"Results[{idx}]:\n{output}") + + +def _build_parser(): + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts and output by this example. Default ./llama_qnn", + default="./llama_qnn", + type=str, + ) + + parser.add_argument( + "-P", + "--ptq", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.", + type=str, + ) + + parser.add_argument( + "--llama_model", + choices=["stories110m", "llama3_2", "qwen3"], + help="The Llama model to export. Current available options are: [stories110m, llama3_2, qwen3]", + required=True, + ) + + parser.add_argument( + "--model_dir", + help="Pass llama checkpoint.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_bin", + help="For Llama2. Pass Llama2 tokenizer binary.", + required=False, + type=str, + ) + + parser.add_argument( + "--tokenizer_model", + help="Pass llama tokenizer model.", + type=str, + default=None, + ) + + parser.add_argument( + "--prompt", + help="User prompts for llama.", + required=True, + type=str, + ) + + parser.add_argument( + "--system_prompt", + help="For Llama3. 
Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None", + default="", + type=str, + ) + + parser.add_argument( + "--temperature", + help="Sampling temperature for llama.", + default=0.8, + type=float, + ) + + parser.add_argument( + "-d", + "--dtype-override", + default="fp32", + type=str, + choices=["fp32", "fp16"], + help="Override the dtype of the model (default is the checkpoint dtype). Options: fp32", + ) + + parser.add_argument( + "--pre_gen_pte", + help="Run the pre-generated llama in the given directory.", + type=str, + ) + + parser.add_argument( + "--num_sharding", + type=int, + default=1, + help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.", + ) + + parser.add_argument( + "--model_mode", + help="Export and inference kv mode or hybrid mode", + default="kv", + choices=["kv", "hybrid"], + type=str, + ) + + parser.add_argument( + "--max_seq_len", + help="This refers to maximum number of tokens that the model can process & consider at once to generate predictions/responses.", + default=512, + type=int, + ) + + parser.add_argument( + "--prefill_ar_len", + help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. 
Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.", + default=32, + type=int, + ) + + parser.add_argument( + "--kv_updater", + help="Choose how to update kv cache during runtime", + choices=["smart_mask", "shift_pointer"], + default="smart_mask", + type=str, + ) + + parser.add_argument( + "-E", + "--embedding-quantize", + default=None, + type=str, + help="Fallback to cpu embedding operator and type of embedding quantization, ',', e.g., '4,32'.", + ) + + parser.add_argument( + "--use_tman", + action="store_true", + help="Use TMANLinear instead of QNNConv2d.", + ) + + parser.add_argument("-v", "--verbose", action="store_true") + + return parser + + +def export_llama(args) -> None: + if args.compile_only and args.pre_gen_pte: + exit("Cannot set both compile_only and pre_gen_pte as true") + + if args.model_mode == "kv": + pte_filename = "kv_llama_qnn" + elif args.model_mode == "hybrid": + assert ( + args.max_seq_len >= args.prefill_ar_len + ), "Please ensure max_seq_len is >= prefill_ar_len" + pte_filename = "hybrid_llama_qnn" + else: + raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + tokenizer = get_tokenizer(args.tokenizer_model) + runtime_tokenizer_path = "" + if args.llama_model == "stories110m": + assert isinstance( + tokenizer, SentencePieceTokenizer + ), f"Wrong tokenizer provided for stories110m." + assert ( + args.tokenizer_bin is not None + ), "Please provide tokenizer_bin for stories110m." + runtime_tokenizer_path = args.tokenizer_bin + elif args.llama_model == "llama3_2": + assert isinstance( + tokenizer, TiktokenTokenizer + ), f"Wrong tokenizer provided for llama3_2." + runtime_tokenizer_path = args.tokenizer_model + elif args.llama_model == "qwen3": + assert isinstance( + tokenizer, HuggingFaceTokenizer + ), f"Wrong tokenizer provided for qwen3." 
+ runtime_tokenizer_path = args.tokenizer_model + else: + raise RuntimeError(f"Unknown llama_model: {args.llama_model}.") + + if args.kv_updater == "smart_mask": + args.shared_buffer = True + args.kv_updater = smart_mask_updater + elif args.kv_updater == "shift_pointer": + args.kv_updater = shift_pointer_updater + else: + exit(f"Using an unknown kv update {args.kv_updater}") + + if args.pre_gen_pte: + quant_attrs = json.load( + open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt") + ) + inference( + args, quant_attrs, pte_filename, runtime_tokenizer_path, args.pre_gen_pte + ) + exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") + + if args.compile_only: + quant_attrs = compile(args, pte_filename, tokenizer) + if quant_attrs: + json.dump( + { + "scale": quant_attrs["scale"], + "zero_point": quant_attrs["zero_point"], + }, + open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"), + ) + else: + logging.warning("Quant attributes of the logit is None.") + + if args.ip and args.port != -1: + pte_path = f"{args.artifact}/{pte_filename}.pte" + pte_size = os.path.getsize(pte_path) + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps( + { + "pte_size": pte_size, + } + ) + ) + exit(f"Finish compile_only and save to {args.artifact}") + + try: + quant_attrs = compile(args, pte_filename, tokenizer) + if quant_attrs: + logging.info( + f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}" + ) + json.dump( + { + "scale": quant_attrs["scale"], + "zero_point": quant_attrs["zero_point"], + }, + open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"), + ) + else: + logging.warning("Quant attributes of the logit is None.") + inference(args, quant_attrs, pte_filename, runtime_tokenizer_path) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) + + +def main(): + parser = _build_parser() + args 
= parser.parse_args() + args.prompt = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".format(args.prompt) + export_llama(args) + + +# flake8: noqa: C901 +if __name__ == "__main__": + main() From c6fd30c29baedcb72593643edba7674eafe5db1e Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Tue, 13 May 2025 04:16:03 +0000 Subject: [PATCH 06/11] LlamaDemo: Fix progress bug when file size larger than 2GB --- .../executorchllamademo/SettingsActivity.java | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java index fa0066188b6..ef452decf43 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java @@ -350,6 +350,24 @@ private static class ModelInfo { "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn.pte?download=true", "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn_quant_attrs.txt?download=true" ), + new ModelInfo( + "llama-3.1-8B-Instruct", + "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_tokenizer.json?download=true", + "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_instruct_bitdistiller.pte?download=true", + "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_quant_attrs.txt?download=true" + ), + new ModelInfo( + "llama-3.1-8B-Instruct-LongContext", + "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_tokenizer.json?download=true", + "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_instruct_bitdistiller_ctx1024.pte?download=true", + 
"https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_ctx1024_quant_attrs.txt?download=true" + ), + new ModelInfo( + "Qwen-3-8B", + "https://huggingface.co/Qwen/Qwen3-8B/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller.pte?download=true", + "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller_quant_attrs.txt?download=true" + ) }; private final String mOpPackageUrl = "https://huggingface.co/JY-W/test_model/resolve/main/libQnnTMANOpPackage.so?download=true"; @@ -382,7 +400,7 @@ private void downloadFileFromUrl(String fileUrl, String fileName, boolean overwr return; } - int fileLength = connection.getContentLength(); + long fileLength = connection.getContentLengthLong(); InputStream input = connection.getInputStream(); try (FileOutputStream output = new FileOutputStream(outputFile)) { From b81020c60a6a48d679c71d8057a3f90bbe6d3108 Mon Sep 17 00:00:00 2001 From: Hansong <107070759+kirklandsign@users.noreply.github.com> Date: Mon, 12 May 2025 14:23:22 -0700 Subject: [PATCH 07/11] Android Qwen thinking mode prompt support (#10668) Use different prompts according to mode --- .../executorchllamademo/MainActivity.java | 12 +++++--- .../executorchllamademo/PromptFormat.java | 29 ++++++++++++++----- .../executorchllamademo/SettingsActivity.java | 7 +++-- .../executorchllamademo/SettingsFields.java | 14 +++++---- 4 files changed, 42 insertions(+), 20 deletions(-) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 87e9436b581..37268202b69 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -692,7 +692,10 @@ 
private String getConversationHistory() { prevPromptID = currentPromptID; } if (conversation.getIsSent()) { - format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText()); + format = + format + .replace(PromptFormat.USER_PLACEHOLDER, conversation.getText()) + .replace(PromptFormat.THINKING_MODE_PLACEHOLDER, ""); } else { format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText()); } @@ -704,12 +707,12 @@ private String getConversationHistory() { private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) { if (conversationHistory.isEmpty()) { - return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); + return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode); } return mCurrentSettingsFields.getFormattedSystemPrompt() + conversationHistory - + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt); + + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt, mThinkMode); } private void onModelRunStarted() { @@ -738,7 +741,8 @@ private void onModelRunStopped() { if (ModelUtils.getModelCategory( mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()) == ModelUtils.VISION_MODEL) { - finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); + finalPrompt = + mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode); } else { finalPrompt = getTotalFormattedPrompt(getConversationHistory(), rawPrompt); } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 76c4d5f3b16..5f8ecdd8042 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -13,6 +13,7 
@@ public class PromptFormat { public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}"; public static final String USER_PLACEHOLDER = "{{ user_prompt }}"; public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}"; + public static final String THINKING_MODE_PLACEHOLDER = "{{ thinking_mode }}"; public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences"; public static String getSystemPromptTemplate(ModelType modelType) { @@ -32,7 +33,7 @@ public static String getSystemPromptTemplate(ModelType modelType) { } } - public static String getUserPromptTemplate(ModelType modelType) { + public static String getUserPromptTemplate(ModelType modelType, boolean thinkingMode) { switch (modelType) { case LLAMA_3: case LLAMA_3_1: @@ -43,15 +44,13 @@ public static String getUserPromptTemplate(ModelType modelType) { + "<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>"; - case LLAVA_1_5: case QWEN_3: return "<|im_start|>user\n" + USER_PLACEHOLDER - + "<|im_end|>\n" + + "\n<|im_end|>\n" + "<|im_start|>assistant\n" - + "\n" - + "\n" - + "\n\n\n"; + + THINKING_MODE_PLACEHOLDER; + case LLAVA_1_5: default: return USER_PLACEHOLDER; } @@ -62,9 +61,14 @@ public static String getConversationFormat(ModelType modelType) { case LLAMA_3: case LLAMA_3_1: case LLAMA_3_2: - return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>"; + return getUserPromptTemplate(modelType, false) + + "\n" + + ASSISTANT_PLACEHOLDER + + "<|eot_id|>"; case LLAVA_1_5: return USER_PLACEHOLDER + " ASSISTANT:"; + case QWEN_3: + return getUserPromptTemplate(modelType, false) + "<|im_end|>\n"; default: return USER_PLACEHOLDER; } @@ -86,13 +90,22 @@ public static String getStopToken(ModelType modelType) { } } + public static String getThinkingModeToken(ModelType modelType, boolean thinkingMode) { + switch (modelType) { + case QWEN_3: + return thinkingMode ? 
"" : "\n\n\n\n\n"; + default: + return ""; + } + } + public static String getLlavaPresetPrompt() { return "A chat between a curious human and an artificial intelligence assistant. The assistant" + " gives helpful, detailed, and polite answers to the human's questions. USER: "; } public static String getFormattedLlamaGuardPrompt(String userPrompt) { - return getUserPromptTemplate(ModelType.LLAMA_GUARD_3) + return getUserPromptTemplate(ModelType.LLAMA_GUARD_3, false) .replace( USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt)); } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java index ef452decf43..f5b6175c746 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java @@ -276,7 +276,8 @@ public void afterTextChanged(Editable s) { new DialogInterface.OnClickListener() { public void onClick(DialogInterface dialog, int whichButton) { // Clear the messageAdapter and sharedPreference - mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType)); + mUserPromptEditText.setText( + PromptFormat.getUserPromptTemplate(mModelType, false)); } }) .setNegativeButton(android.R.string.no, null) @@ -299,7 +300,7 @@ private void showInvalidPromptDialog() { .setPositiveButton( android.R.string.yes, (dialog, whichButton) -> { - mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType)); + mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false)); }) .setNegativeButton(android.R.string.no, null) .show(); @@ -533,7 +534,7 @@ private void setupModelTypeSelectorDialog() { (dialog, item) -> { mModelTypeTextView.setText(modelTypes[item]); mModelType 
= ModelType.valueOf(modelTypes[item]); - mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType)); + mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType, false)); dialog.dismiss(); }); diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java index 3adadf574da..94036f43947 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java @@ -38,8 +38,8 @@ public String getUserPrompt() { return userPrompt; } - public String getFormattedSystemAndUserPrompt(String prompt) { - return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt); + public String getFormattedSystemAndUserPrompt(String prompt, boolean thinkingMode) { + return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt, thinkingMode); } public String getFormattedSystemPrompt() { @@ -47,8 +47,12 @@ public String getFormattedSystemPrompt() { .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt); } - public String getFormattedUserPrompt(String prompt) { - return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt); + public String getFormattedUserPrompt(String prompt, boolean thinkingMode) { + return userPrompt + .replace(PromptFormat.USER_PLACEHOLDER, prompt) + .replace( + PromptFormat.THINKING_MODE_PLACEHOLDER, + PromptFormat.getThinkingModeToken(modelType, thinkingMode)); } public boolean getIsClearChatHistory() { @@ -77,7 +81,7 @@ public SettingsFields() { tokenizerFilePath = ""; temperature = SettingsActivity.TEMPERATURE_MIN_VALUE; systemPrompt = ""; - userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL); + userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL, false); isClearChatHistory = 
false; isLoadModel = false; modelType = DEFAULT_MODEL; From 9f08e74fbef53dc7f4e142d2c0e3b4366353eeba Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Tue, 27 May 2025 14:45:47 +0000 Subject: [PATCH 08/11] LlamaDemo: Update model urls and tokenizers to support Qwen3 eos_token --- .gitmodules | 2 +- .../executorchllamademo/PromptFormat.java | 8 +- .../executorchllamademo/SettingsActivity.java | 159 +++++++++++++----- .../oss_scripts/llama/runner/runner.cpp | 22 +-- extension/android/jni/jni_layer_llama.cpp | 3 +- extension/llm/tokenizers | 2 +- 6 files changed, 137 insertions(+), 59 deletions(-) diff --git a/.gitmodules b/.gitmodules index 8b51b55a13a..6df65508856 100644 --- a/.gitmodules +++ b/.gitmodules @@ -30,7 +30,7 @@ url = https://github.com/Maratyszcza/pthreadpool.git [submodule "extension/llm/tokenizers"] path = extension/llm/tokenizers - url = https://github.com/pytorch-labs/tokenizers.git + url = https://github.com/kaleid-liner/tokenizers.git [submodule "kernels/optimized/third-party/eigen"] path = kernels/optimized/third-party/eigen url = https://gitlab.com/libeigen/eigen.git diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 5f8ecdd8042..b3e0aa9b0ac 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -27,7 +27,7 @@ public static String getSystemPromptTemplate(ModelType modelType) { case LLAVA_1_5: return "USER: "; case QWEN_3: - return "<|im_start|>system\n" + "You are a helpful assistant.\n" + "<|im_end|>\n"; + return ""; default: return SYSTEM_PLACEHOLDER; } @@ -47,7 +47,7 @@ public static String getUserPromptTemplate(ModelType modelType, boolean thinking case QWEN_3: return "<|im_start|>user\n" + 
USER_PLACEHOLDER - + "\n<|im_end|>\n" + + "<|im_end|>\n" + "<|im_start|>assistant\n" + THINKING_MODE_PLACEHOLDER; case LLAVA_1_5: @@ -84,7 +84,7 @@ public static String getStopToken(ModelType modelType) { case LLAVA_1_5: return ""; case QWEN_3: - return "<|endoftext|>"; + return "<|im_end|>"; default: return ""; } @@ -93,7 +93,7 @@ public static String getStopToken(ModelType modelType) { public static String getThinkingModeToken(ModelType modelType, boolean thinkingMode) { switch (modelType) { case QWEN_3: - return thinkingMode ? "" : "\n\n\n\n\n"; + return thinkingMode ? "" : "\n\n\n\n"; default: return ""; } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java index f5b6175c746..645cff63d14 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java @@ -36,6 +36,7 @@ import java.io.FileInputStream; import java.io.InputStream; import java.net.URL; +import android.util.Log; public class SettingsActivity extends AppCompatActivity { @@ -332,12 +333,14 @@ private void setupBackendSelectorDialog() { private static class ModelInfo { String modelName; String tokenizerUrl; + String tokenizerConfigUrl; String modelUrl; String quantAttrsUrl; - ModelInfo(String modelName, String tokenizerUrl, String modelUrl, String quantAttrsUrl) { + ModelInfo(String modelName, String tokenizerUrl, String tokenizerConfigUrl, String modelUrl, String quantAttrsUrl) { this.modelName = modelName; this.tokenizerUrl = tokenizerUrl; + this.tokenizerConfigUrl = tokenizerConfigUrl; this.modelUrl = modelUrl; this.quantAttrsUrl = quantAttrsUrl; } @@ -345,35 +348,105 @@ private static class ModelInfo { // Construct the model info array private final 
ModelInfo[] modelInfoArray = new ModelInfo[] { + // new ModelInfo( + // "bitnet-b1.58-2B-4T", + // "https://huggingface.co/JY-W/test_model/resolve/main/tokenizer.json?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn.pte?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn_quant_attrs.txt?download=true" + // ), + // new ModelInfo( + // "llama-3.1-8B-Instruct", + // "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_tokenizer.json?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_instruct_bitdistiller.pte?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_quant_attrs.txt?download=true" + // ), + // new ModelInfo( + // "llama-3.1-8B-Instruct-LongContext", + // "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_tokenizer.json?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_instruct_bitdistiller_ctx1024.pte?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_ctx1024_quant_attrs.txt?download=true" + // ), + // new ModelInfo( + // "Qwen-3-8B", + // "https://huggingface.co/Qwen/Qwen3-8B/resolve/main/tokenizer.json?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller.pte?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller_quant_attrs.txt?download=true" + // ), + // new ModelInfo( + // "Qwen-3-8B-LongContext", + // "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_tokenizer.json?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller_ctx1024.pte?download=true", + // "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller_ctx1024_quant_attrs.txt?download=true" + // ) new ModelInfo( - "bitnet-b1.58-2B-4T", - 
"https://huggingface.co/JY-W/test_model/resolve/main/tokenizer.json?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn.pte?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/kv_llama_qnn_quant_attrs.txt?download=true" + "BitNet-b1.58-2B-4T-Fast", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/model.pte?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/quant_attrs.txt?download=true" ), new ModelInfo( - "llama-3.1-8B-Instruct", - "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_tokenizer.json?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_instruct_bitdistiller.pte?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_quant_attrs.txt?download=true" + "BitNet-b1.58-2B-4T", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/model_1k.pte?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/quant_attrs_1k.txt?download=true" ), new ModelInfo( - "llama-3.1-8B-Instruct-LongContext", - "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_tokenizer.json?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_instruct_bitdistiller_ctx1024.pte?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/llama_3_1_8b_bitdistiller_ctx1024_quant_attrs.txt?download=true" + "BitNet-b1.58-2B-4T-LongContext", + 
"https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/model_4k.pte?download=true", + "https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/quant_attrs_4k.txt?download=true" + ), + new ModelInfo( + "Llama-3.1-8B-Instruct-Fast", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/model.pte?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/quant_attrs.txt?download=true" + ), + new ModelInfo( + "Llama-3.1-8B-Instruct", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/model_1k.pte?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/quant_attrs_1k.txt?download=true" + ), + new ModelInfo( + "Llama-3.1-8B-Instruct-LongContext", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/model_4k.pte?download=true", + "https://huggingface.co/BitDistiller/Llama-3.1-8B-Instruct-w2g64-pte/resolve/main/quant_attrs_4k.txt?download=true" + ), + new ModelInfo( + 
"Qwen-3-8B-Fast", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/model.pte?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/quant_attrs.txt?download=true" ), new ModelInfo( "Qwen-3-8B", - "https://huggingface.co/Qwen/Qwen3-8B/resolve/main/tokenizer.json?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller.pte?download=true", - "https://huggingface.co/JY-W/test_model/resolve/main/qwen3_8b_bitdistiller_quant_attrs.txt?download=true" + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/model_1k.pte?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/quant_attrs_1k.pte?download=true" + ), + new ModelInfo( + "Qwen-3-8B-LongContext", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/model_4k.pte?download=true", + "https://huggingface.co/BitDistiller/Qwen-8B-w2g64-pte/resolve/main/quant_attrs_4k.pte?download=true" ) }; - private final String mOpPackageUrl = "https://huggingface.co/JY-W/test_model/resolve/main/libQnnTMANOpPackage.so?download=true"; + // private final String mOpPackageUrl = "https://huggingface.co/JY-W/test_model/resolve/main/libQnnTMANOpPackage.so?download=true"; + private final String mOpPackageUrl = 
"https://huggingface.co/BitDistiller/BitNet-2B-4T-pte/resolve/main/libQnnTMANOpPackage.so?download=true"; - private void downloadFileFromUrl(String fileUrl, String fileName, boolean overwrite) { + private void downloadFileFromUrl(String fileUrl, String fileName, File outputFile, boolean overwrite) { ProgressDialog progressDialog = new ProgressDialog(this); progressDialog.setTitle("Downloading " + fileName + "..."); progressDialog.setMessage("Please wait..."); @@ -386,8 +459,6 @@ private void downloadFileFromUrl(String fileUrl, String fileName, boolean overwr new Thread(() -> { try { - File outputDir = getExternalFilesDir(null); - File outputFile = new File(outputDir, fileName); if (!outputFile.exists() || overwrite) { URL url = new URL(fileUrl); HttpURLConnection connection = (HttpURLConnection) url.openConnection(); @@ -437,44 +508,50 @@ private void downloadFileFromUrl(String fileUrl, String fileName, boolean overwr } private void downloadModel(ModelInfo modelInfo) { - String modelFileName = modelInfo.modelName + ".pte"; - String tokenizerFileName = modelInfo.modelName + "_tokenizer.json"; - String quantAttrsFileName = modelInfo.modelName + "_quant_attrs.txt"; + String modelFileName = "model.pte"; + String tokenizerFileName = "tokenizer.json"; + String tokenizerConfigFileName = "tokenizer_config.json"; + String quantAttrsFileName = "quant_attrs.txt"; String opPackageFileName = "libQnnTMANOpPackage.so"; - File modelFile = new File(getExternalFilesDir(null), modelFileName); - File tokenizerFile = new File(getExternalFilesDir(null), tokenizerFileName); - File quantAttrsFile = new File(getExternalFilesDir(null), quantAttrsFileName); - File opPackageFile = new File(getExternalFilesDir(null), opPackageFileName); + File modelDir = getExternalFilesDir(modelInfo.modelName); + File modelFile = new File(modelDir, modelFileName); + File tokenizerFile = new File(modelDir, tokenizerFileName); + File tokenizerConfigFile = new File(modelDir, tokenizerConfigFileName); + File 
quantAttrsFile = new File(modelDir, quantAttrsFileName); + File opPackageFile = new File(modelDir, opPackageFileName); - if (modelFile.exists() || tokenizerFile.exists() || quantAttrsFile.exists() || opPackageFile.exists()) { + if (modelFile.exists() || tokenizerFile.exists() || tokenizerConfigFile.exists() || quantAttrsFile.exists() || opPackageFile.exists()) { runOnUiThread(() -> { new AlertDialog.Builder(this) .setTitle("Overwrite Existing Files") .setMessage("Some files for this model already exist. Do you want to overwrite them?") .setPositiveButton("Yes", (dialog, which) -> { - downloadFileFromUrl(modelInfo.modelUrl, modelFileName, true); - downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, true); - downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, true); - downloadFileFromUrl(mOpPackageUrl, opPackageFileName, true); + downloadFileFromUrl(modelInfo.modelUrl, modelFileName, modelFile, true); + downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, tokenizerFile, true); + downloadFileFromUrl(modelInfo.tokenizerConfigUrl, tokenizerConfigFileName, tokenizerConfigFile, true); + downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, quantAttrsFile, true); + downloadFileFromUrl(mOpPackageUrl, opPackageFileName, opPackageFile, true); }) .setNegativeButton("No", (dialog, which) -> { - downloadFileFromUrl(modelInfo.modelUrl, modelFileName, false); - downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, false); - downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, false); - downloadFileFromUrl(mOpPackageUrl, opPackageFileName, false); + downloadFileFromUrl(modelInfo.modelUrl, modelFileName, modelFile, false); + downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, tokenizerFile, false); + downloadFileFromUrl(modelInfo.tokenizerConfigUrl, tokenizerConfigFileName, tokenizerConfigFile, false); + downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, quantAttrsFile, false); + 
downloadFileFromUrl(mOpPackageUrl, opPackageFileName, opPackageFile, false); }) .show(); }); } else { - downloadFileFromUrl(modelInfo.modelUrl, modelFileName, true); - downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, true); - downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, true); - downloadFileFromUrl(mOpPackageUrl, opPackageFileName, true); + downloadFileFromUrl(modelInfo.modelUrl, modelFileName, modelFile, true); + downloadFileFromUrl(modelInfo.tokenizerUrl, tokenizerFileName, tokenizerFile, true); + downloadFileFromUrl(modelInfo.tokenizerConfigUrl, tokenizerConfigFileName, tokenizerConfigFile, true); + downloadFileFromUrl(modelInfo.quantAttrsUrl, quantAttrsFileName, quantAttrsFile, true); + downloadFileFromUrl(mOpPackageUrl, opPackageFileName, opPackageFile, true); } mModelFilePath = modelFile.getAbsolutePath(); mModelTextView.setText(getFilenameFromPath(mModelFilePath)); - mTokenizerFilePath = tokenizerFile.getAbsolutePath(); + mTokenizerFilePath = modelDir.getAbsolutePath(); mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath)); } diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 40e34fbd1ac..a828036a695 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -192,35 +192,35 @@ Error Runner::load() { } // llama3 tokenizer - tokenizer_ = example::get_tiktoken_for_llama(); + tokenizer_ = std::make_unique(); auto err = tokenizer_->load(tokenizer_path_); if (err != tokenizers::Error::Ok) { ET_LOG( Info, - "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", + "Failed to load %s as a HF tokenizer artifact, trying Tiktoken", tokenizer_path_.c_str()); tokenizer_.reset(); - // llama2 tokenizer - tokenizer_ = std::make_unique(); - err = tokenizer_->load(tokenizer_path_); - llama_version_ = LlamaVersion::kLlama2; + tokenizer_ = example::get_tiktoken_for_llama(); 
+ auto err = tokenizer_->load(tokenizer_path_); if (err != tokenizers::Error::Ok) { ET_LOG( Info, - "Failed to load %s as a llama2.c tokenizer artifact", + "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", tokenizer_path_.c_str()); tokenizer_.reset(); - tokenizer_ = std::make_unique(); + // llama2 tokenizer + tokenizer_ = std::make_unique(); err = tokenizer_->load(tokenizer_path_); ET_CHECK_MSG( err == tokenizers::Error::Ok, "failed to load tokenizer %s", tokenizer_path_.c_str()); + llama_version_ = LlamaVersion::kLlama2; + } else { eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); llama_version_ = LlamaVersion::kLlama3; } } else { - eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); llama_version_ = LlamaVersion::kLlama3; } bos_id_ = tokenizer_->bos_tok(); @@ -517,9 +517,11 @@ Error Runner::generate( ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null"); int32_t seq_len = config.seq_len; + // The passed in seq_len could be too short. Use context_len_ for now + seq_len = context_len_; seq_len = (seq_len > 0 && seq_len <= context_len_) ? 
seq_len : context_len_; tokenizers::Result> encode_res = - tokenizer_->encode(prompt, n_bos_, 0); + tokenizer_->encode(prompt, 0, 0); // set to 0 to avoid bos at the beginning ET_CHECK_TK_OK_OR_RETURN_ERROR( encode_res.error(), "failed to encode prompt %s", prompt.c_str()); diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index dde0be2ba20..577394c0bf1 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -180,8 +180,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { std::filesystem::path op_package_fs = model_path_fs.parent_path() / "libQnnTMANOpPackage.so"; std::string op_package_str = op_package_fs.string() + ":TMANOpPackageInterfaceProvider:HTP"; - std::string quant_attrs_str = model_path_fs.stem().string() + "_quant_attrs.txt"; - std::filesystem::path quant_attrs_fs = model_path_fs.parent_path() / quant_attrs_str; + std::filesystem::path quant_attrs_fs = model_path_fs.parent_path() / "quant_attrs.txt"; std::ifstream quant_attrs_file(quant_attrs_fs.string()); if (!quant_attrs_file.is_open()) { diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers index 9ceef562d5c..c1891ac1a48 160000 --- a/extension/llm/tokenizers +++ b/extension/llm/tokenizers @@ -1 +1 @@ -Subproject commit 9ceef562d5c941eb6aea5476c768d0419962bc0c +Subproject commit c1891ac1a48287d68303f60fc6a193aafe5f0c4e From 817bd29dfbea4979a27f2f539c753bd46757ea7a Mon Sep 17 00:00:00 2001 From: Jianyu Wei Date: Tue, 3 Jun 2025 09:30:28 +0000 Subject: [PATCH 09/11] [Fix] TMANOpPackage: Place weights in TCM to leverage DMA This is a miss-imported commit. The released model and libQnnTMANOpPackage.so should be built with this updates already. 
--- .../op_packages/TMANOpPackage/src/ops/TMANLinear.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp index d33f5494f9b..d8e3dfcbbea 100644 --- a/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp +++ b/backends/qualcomm/runtime/op_packages/TMANOpPackage/src/ops/TMANLinear.cpp @@ -62,9 +62,9 @@ DEF_PACKAGE_OP((tmanlinearImpl), "TMANLinear") DEF_TENSOR_PROPERTIES( Op("TMANLinear", "l", "qweight", "scales", "group_size", "bits", "symmetric"), - Flat("*", "qweight", "scales"), - MainMemory("qweight", "scales", "group_size", "bits", "symmetric"), - Tcm("*", "l")) + Flat("*", "l", "qweight", "scales"), + MainMemory("group_size", "bits", "symmetric"), + Tcm("*", "l", "qweight", "scales")) #define SIZE_OF(WEIGHT) MUL(ELEMENTSIZE_OF(WEIGHT), DIM_OF(WEIGHT, 0), DIM_OF(WEIGHT, 1), DIM_OF(WEIGHT, 2), DIM_OF(WEIGHT, 3)) From 110796c28f6a0bd1ae1907dd702124e237c7bcf5 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Fri, 14 Nov 2025 14:24:05 -0800 Subject: [PATCH 10/11] update tokenizers --- .gitmodules | 3 +++ CMakeLists.txt | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 4f2176e5399..1f202d4fdec 100644 --- a/.gitmodules +++ b/.gitmodules @@ -25,6 +25,9 @@ [submodule "backends/xnnpack/third-party/pthreadpool"] path = backends/xnnpack/third-party/pthreadpool url = https://github.com/google/pthreadpool.git +[submodule "extension/llm/tokenizers"] + path = extension/llm/tokenizers + url = https://github.com/meta-pytorch/tokenizers.git [submodule "kernels/optimized/third-party/eigen"] path = kernels/optimized/third-party/eigen url = https://gitlab.com/libeigen/eigen.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 51573d276b3..10c15477d08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -658,7 +658,6 @@ 
if(EXECUTORCH_BUILD_EXTENSION_LLM) ) set(CMAKE_POSITION_INDEPENDENT_CODE ON) endif() - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/tokenizers) if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) set(CMAKE_POSITION_INDEPENDENT_CODE ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG} From ca919cc921d561cd284a183a0adf039acd18ada3 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Fri, 14 Nov 2025 16:03:23 -0800 Subject: [PATCH 11/11] fix tokenizers --- CMakeLists.txt | 1 + backends/qualcomm/runtime/QnnManager.cpp | 69 +------------------ .../runtime/backends/QnnBackendFactory.cpp | 8 +-- .../runtime/backends/QnnBackendFactory.h | 3 +- backends/qualcomm/scripts/build.sh | 3 - extension/llm/tokenizers | 1 + 6 files changed, 9 insertions(+), 76 deletions(-) create mode 160000 extension/llm/tokenizers diff --git a/CMakeLists.txt b/CMakeLists.txt index 10c15477d08..51573d276b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -658,6 +658,7 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM) ) set(CMAKE_POSITION_INDEPENDENT_CODE ON) endif() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/tokenizers) if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) set(CMAKE_POSITION_INDEPENDENT_CODE ${ORIGINAL_CMAKE_POSITION_INDEPENDENT_CODE_FLAG} diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index de1e10cc44f..5e3220f25d9 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -25,48 +25,6 @@ namespace executorch { namespace backends { namespace qnn { -namespace { -// TODO: [ERROR] [Qnn ExecuTorch]: tcm_migration.cc:174:ERROR:Memory properties specified twice for operator ::TMANLinear -// The root cause of this error is that when QNN backend is freed, the memory properties of custom ops are not cleared, -// which will cause the error when the QNN backend is loaded and custom ops are registered again. -// This is a bug in QNN SDK, related to DEF_TENSOR_PROPERTIES / hnnx::register_tensor_properties. 
-// Workaround: prevent the QNN backend from being freed. -class GlobalBackend { -public: - QnnImplementation implementation_; - QnnLogger* logger_; - QnnBackend* backend_; - - static GlobalBackend& GetInstance() { - static GlobalBackend instance; - return instance; - } - ~GlobalBackend() { - if (backend_) { - delete backend_; - backend_ = nullptr; - } - if (logger_) { - delete logger_; - logger_ = nullptr; - } - } -private: - GlobalBackend() - : implementation_("libQnnHtp.so") { - implementation_.Load(nullptr); - logger_ = new QnnLogger( - implementation_, LoggingCallback, QnnExecuTorchLogLevel::kLogLevelWarn); - backend_ = new HtpBackend(implementation_, logger_); - Error error = backend_->Configure(); - if (error != Error::Ok) { - QNN_EXECUTORCH_LOG_ERROR( - "Failed to configure backend. Error code: %d", error); - } - }; -}; -} - using executorch::runtime::Error; bool CompareExportedInput( @@ -153,8 +111,7 @@ QnnManager::QnnManager( break; } } - // qnn_loaded_backend_ = QnnImplementation(library_path); - qnn_loaded_backend_ = GlobalBackend::GetInstance().implementation_; + qnn_loaded_backend_ = QnnImplementation(library_path); backend_params_ptr_ = std::make_unique(); qnn_dlc_manager_ = @@ -232,19 +189,6 @@ Error QnnManager::RegisterMem( void* custom_mem_base = shared_buffer_manager.GetCustomMemBase(data_ptr); if (custom_mem_base != nullptr) { - size_t tensor_bytes = 0; - for (const auto& info : shared_buffer_manager.GetCustomMemTensorInfoSet()) { - if (info.tensor_addr == data_ptr) { - tensor_bytes = info.tensor_bytes; - } - } - if (tensor_bytes != tensor_wrapper->GetBytes()) { - QNN_EXECUTORCH_LOG_WARN( - "Tensor %s size %u is not equal to custom mem size %zu\n", - tensor_wrapper->GetName().c_str(), - tensor_wrapper->GetBytes(), - tensor_bytes); - } return RegisterCustomMem(data_ptr, custom_mem_base, tensor_wrapper); } return RegisterIonMem(data_ptr, tensor_wrapper); @@ -375,8 +319,7 @@ Error QnnManager::Init() { logger_.get(), qnn_context_blob_, options_, - 
qnn_dlc_manager_.get(), - std::unique_ptr(GlobalBackend::GetInstance().backend_)); + qnn_dlc_manager_.get()); ET_CHECK_OR_RETURN_ERROR( backend_params_ptr_ != nullptr, Internal, @@ -572,17 +515,11 @@ Error QnnManager::ProfileExecuteData( void QnnManager::Destroy() { QNN_EXECUTORCH_LOG_INFO("Destroy Qnn backend parameters"); - if (backend_params_ptr_->qnn_backend_ptr_ != nullptr) { - GlobalBackend::GetInstance().backend_ = backend_params_ptr_->qnn_backend_ptr_.release(); - } - if (logger_ != nullptr) { - GlobalBackend::GetInstance().logger_ = logger_.release(); - } backend_params_ptr_.reset(new BackendConfigParameters()); qnn_dlc_manager_->ResetBackendParams(); logger_.reset(); qnn_dlc_manager_->ResetLogger(); - // qnn_loaded_backend_.TerminateAllBackends(); + qnn_loaded_backend_.TerminateAllBackends(); qnn_dlc_manager_->TerminateAllBackends(); } diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp index 85e768b07ff..e7e9db6fed8 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp @@ -20,8 +20,7 @@ std::unique_ptr QnnBackendFactory::Create( QnnLogger* logger, const QnnExecuTorchContextBinary& qnn_context_blob, const QnnExecuTorchOptions* options, - QnnDlcManager* qnn_dlc_manager, - std::unique_ptr&& backend_ptr) { + QnnDlcManager* qnn_dlc_manager) { auto backend_params = std::make_unique(); switch (options->backend_options()->backend_type()) { @@ -57,9 +56,8 @@ std::unique_ptr QnnBackendFactory::Create( QNN_EXECUTORCH_LOG_INFO( "use_fold_relu in htp_options: %d", htp_options->use_fold_relu()); } - // backend_params->qnn_backend_ptr_ = - // std::make_unique(implementation, logger); - backend_params->qnn_backend_ptr_ = std::move(backend_ptr); + backend_params->qnn_backend_ptr_ = + std::make_unique(implementation, logger); backend_params->qnn_device_ptr_ = std::make_unique( implementation, logger, 
options->soc_info(), htp_options); diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.h b/backends/qualcomm/runtime/backends/QnnBackendFactory.h index ae4c3562284..3d78a36b9f0 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendFactory.h +++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.h @@ -70,8 +70,7 @@ class QnnBackendFactory { QnnLogger* logger, const QnnExecuTorchContextBinary& qnn_context_blob, const QnnExecuTorchOptions* options, - QnnDlcManager* qnn_dlc_manager, - std::unique_ptr&& backend_ptr); + QnnDlcManager* qnn_dlc_manager); }; } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index 2a9bcf71694..83ce4a7369d 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -110,7 +110,6 @@ if [ "$BUILD_ANDROID" = true ]; then -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DANDROID_PLATFORM=android-30 \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ - -DSUPPORT_REGEX_LOOKAHEAD=ON \ -B$BUILD_ROOT cmake --build $BUILD_ROOT -j$BUILD_JOB_NUMBER --target install @@ -130,7 +129,6 @@ if [ "$BUILD_ANDROID" = true ]; then -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ - -DSUPPORT_REGEX_LOOKAHEAD=ON \ -B$EXAMPLE_ROOT cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER @@ -266,7 +264,6 @@ if [ "$BUILD_X86_64" = true ]; then -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ - -DSUPPORT_REGEX_LOOKAHEAD=ON \ -S $PRJ_ROOT \ -B $BUILD_ROOT \ diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers new file mode 160000 index 00000000000..3aada3fe28c --- /dev/null +++ b/extension/llm/tokenizers @@ -0,0 +1 @@ +Subproject commit 3aada3fe28c945d14d5ec62254eb56ccdf10eb11