From 6c65d3e807327085d0df7ea42876653e9d30902f Mon Sep 17 00:00:00 2001
From: Annanya
Date: Sun, 26 Jan 2025 19:20:58 -0500
Subject: [PATCH] Added support for AWQ models

---
 examples/batch_generation.py |  11 +-
 examples/bench.py            |  11 +-
 examples/generation.py       |  10 +-
 install.sh                   |   0
 models/llama.py              | 204 +++++++++++++++++++++--------------
 quantization/__init__.py     |   0
 quantization/awq_utils.py    |  87 +++++++++++++++
 requirements.txt             |   3 +-
 8 files changed, 241 insertions(+), 85 deletions(-)
 mode change 100644 => 100755 install.sh
 create mode 100644 quantization/__init__.py
 create mode 100644 quantization/awq_utils.py

diff --git a/examples/batch_generation.py b/examples/batch_generation.py
index eb1e66e..3c9e4e6 100644
--- a/examples/batch_generation.py
+++ b/examples/batch_generation.py
@@ -1,6 +1,6 @@
 import sys
 sys.path.append("..")
-from models.llama import LLM
+from models.llama import LLM, LLMAwq
 import argparse
 import torch
 from transformers import AutoTokenizer
@@ -14,6 +14,7 @@
 parser.add_argument('--G', type=int, default=32, help='generation length')
 parser.add_argument('--K', type=int, default=10, help='K')
 parser.add_argument('--L', type=int, default=150, help='K')
+parser.add_argument('--awq', action='store_true', help='use LLMAwq')
 args = parser.parse_args()
 print(args)
 MAX_LEN = args.M
@@ -32,7 +33,13 @@
     data = item
     break
 
-llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=BATCH_SIZE, device=DEVICE, dtype=DTYPE)
+if args.awq:
+    print("Using LLMAwq for AWQ optimization.")
+    llm = LLMAwq(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=BATCH_SIZE, device=DEVICE, dtype=DTYPE)
+else:
+    print("Using standard LLM.")
+    llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=BATCH_SIZE, device=DEVICE, dtype=DTYPE)
+
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 text = data["input"]
 input_ids = tokenizer.encode(text=text, return_tensors="pt").to(device=DEVICE)
diff --git a/examples/bench.py b/examples/bench.py
index 9884a9a..756e8a8 100644
--- a/examples/bench.py
+++ b/examples/bench.py
@@ -1,6 +1,6 @@
 import sys
 sys.path.append("..")
-from models.llama import LLM
+from models.llama import LLM, LLMAwq
 import argparse
 import torch
 from transformers import AutoTokenizer
@@ -15,6 +15,7 @@
 parser.add_argument('--G', type=int, default=128, help='generation length')
 parser.add_argument('--K', type=int, default=10, help='K')
 parser.add_argument('--L', type=int, default=150, help='L')
+parser.add_argument('--awq', action='store_true', help='use LLMAwq')
 args = parser.parse_args()
 print(args)
 MAX_LEN = args.M
@@ -33,7 +34,13 @@
     data = item
     break
 
-llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=B, device=DEVICE, dtype=DTYPE)
+if args.awq:
+    print("Using LLMAwq for AWQ optimization.")
+    llm = LLMAwq(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=B, device=DEVICE, dtype=DTYPE)
+else:
+    print("Using standard LLM.")
+    llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=B, device=DEVICE, dtype=DTYPE)
+
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 text = data["input"]
 input_ids = tokenizer.encode(text=text, return_tensors="pt").to(device=DEVICE)
diff --git a/examples/generation.py b/examples/generation.py
index a325be0..00f6a5d 100644
--- a/examples/generation.py
+++ b/examples/generation.py
@@ -1,6 +1,6 @@
 import sys
 sys.path.append("..")
-from models.llama import LLM
+from models.llama import LLM, LLMAwq
 import argparse
 import torch
 from transformers import AutoTokenizer
@@ -16,6 +16,7 @@
 parser.add_argument('--L', type=int, default=150, help='K')
 parser.add_argument('--data', type=str, default="../data/story.txt", help='source data file')
 parser.add_argument('--template', type=str, default="meta-llama3", help='chat template')
+parser.add_argument('--awq', action='store_true', help='use LLMAwq')
 args = parser.parse_args()
 print(args)
 MAX_LEN = args.M
@@ -25,7 +26,12 @@
 DTYPE = torch.bfloat16
 DEVICE = "cuda:0"
 chat_template = Templates[args.template]
-llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=1, device=DEVICE, dtype=DTYPE, generation_buffer=args.G + 32)
+if args.awq:
+    print("Using LLMAwq for AWQ optimization.")
+    llm = LLMAwq(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=1, device=DEVICE, dtype=DTYPE, generation_buffer=args.G + 32)
+else:
+    print("Using standard LLM.")
+    llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=1, device=DEVICE, dtype=DTYPE, generation_buffer=args.G + 32)
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 with open(args.data, "r", encoding="utf-8") as file:
     content = file.read()
diff --git a/install.sh b/install.sh
old mode 100644
new mode 100755
diff --git a/models/llama.py b/models/llama.py
index 5f32898..3442cf3 100644
--- a/models/llama.py
+++ b/models/llama.py
@@ -7,48 +7,100 @@ import flashinfer
 from .attnserver import LSHSparseAttnServer, AttnServer
 import time
 
-class LLMLayer:
-    def __init__(self, layer_idx) -> None:
-
-        self.wq :torch.Tensor = None
-        self.wk :torch.Tensor = None
-        self.wv :torch.Tensor = None
-        self.wo :torch.Tensor = None
+from quantization.awq_utils import AwqLinear
+from abc import ABC, abstractmethod
+
+class BaseLLMLayer(ABC):
+    def __init__(self, layer_idx: int) -> None:
+        # Common parameters
+        self.wq = None
+        self.wk = None
+        self.wv = None
+        self.wo = None
 
-        self.gate_proj :torch.Tensor = None
-        self.up_proj :torch.Tensor = None
-        self.down_proj :torch.Tensor = None
+        self.gate_proj = None
+        self.up_proj = None
+        self.down_proj = None
 
-        self.input_layernorm_weight :torch.Tensor = None
-        self.input_layernorm_variance_epsilon :float = 0.0
+        self.input_layernorm_weight: torch.Tensor = None
+        self.input_layernorm_variance_epsilon: float = 0.0
 
-        self.post_attention_layernorm_weight :torch.Tensor = None
-        self.post_attention_layernorm_variance_epsilon :float = 0.0
+        self.post_attention_layernorm_weight: torch.Tensor = None
+        self.post_attention_layernorm_variance_epsilon: float = 0.0
 
-        self.cos_cache :torch.Tensor = None
-        self.sin_cache :torch.Tensor = None
+        self.cos_cache: torch.Tensor = None
+        self.sin_cache: torch.Tensor = None
 
         self.layer_idx = layer_idx
-
-    def init_parameters(self, hf_layer: LlamaDecoderLayer):
-        self.wq :torch.Tensor= hf_layer.self_attn.q_proj.weight.detach()
-        self.wk :torch.Tensor= hf_layer.self_attn.k_proj.weight.detach()
-        self.wv :torch.Tensor= hf_layer.self_attn.v_proj.weight.detach()
-        self.wo :torch.Tensor= hf_layer.self_attn.o_proj.weight.detach()
+
+    @abstractmethod
+    def init_parameters(self, hf_layer):
+        """Abstract method to initialize parameters from a given layer."""
+        pass
 
-        self.gate_proj = hf_layer.mlp.gate_proj.weight.detach()
-        self.up_proj = hf_layer.mlp.up_proj.weight.detach()
-        self.down_proj = hf_layer.mlp.down_proj.weight.detach()
+    @abstractmethod
+    def init_gpu(self, device: str = 'cuda:0'):
+        """Abstract method to move parameters to GPU."""
+        pass
 
-        self.input_layernorm_weight = hf_layer.input_layernorm.weight.detach()
+class LLMLayer(BaseLLMLayer):
+    def init_parameters(self, hf_layer):
+        # Initialize parameters with type hints
+        self.wq: torch.Tensor = hf_layer.self_attn.q_proj.weight.detach()
+        self.wk: torch.Tensor = hf_layer.self_attn.k_proj.weight.detach()
+        self.wv: torch.Tensor = hf_layer.self_attn.v_proj.weight.detach()
+        self.wo: torch.Tensor = hf_layer.self_attn.o_proj.weight.detach()
+
+        self.gate_proj: torch.Tensor = hf_layer.mlp.gate_proj.weight.detach()
+        self.up_proj: torch.Tensor = hf_layer.mlp.up_proj.weight.detach()
+        self.down_proj: torch.Tensor = hf_layer.mlp.down_proj.weight.detach()
+
+        self.input_layernorm_weight: torch.Tensor = hf_layer.input_layernorm.weight.detach()
         self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon
 
-        self.post_attention_layernorm_weight = hf_layer.post_attention_layernorm.weight.detach()
+        self.post_attention_layernorm_weight: torch.Tensor = hf_layer.post_attention_layernorm.weight.detach()
         self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon
-
-    def init_gpu(self, device:str = 'cuda:0'):
+
+    def init_gpu(self, device: str = 'cuda:0'):
+        # Move parameters to the specified GPU device
+        self.input_layernorm_weight: torch.Tensor = self.input_layernorm_weight.to(device, non_blocking=True)
+        self.post_attention_layernorm_weight: torch.Tensor = self.post_attention_layernorm_weight.to(device, non_blocking=True)
+        self.wq: torch.Tensor = self.wq.to(device, non_blocking=True)
+        self.wk: torch.Tensor = self.wk.to(device, non_blocking=True)
+        self.wv: torch.Tensor = self.wv.to(device, non_blocking=True)
+        self.wo: torch.Tensor = self.wo.to(device, non_blocking=True)
+        self.gate_proj: torch.Tensor = self.gate_proj.to(device, non_blocking=True)
+        self.up_proj: torch.Tensor = self.up_proj.to(device, non_blocking=True)
+        self.down_proj: torch.Tensor = self.down_proj.to(device, non_blocking=True)
+
+class LLMAwqLayer(BaseLLMLayer):
+    def __init__(self, layer_idx: int) -> None:
+        super().__init__(layer_idx)
+        self.wq = AwqLinear()
+        self.wk = AwqLinear()
+        self.wv = AwqLinear()
+        self.wo = AwqLinear()
+
+        self.gate_proj = AwqLinear()
+        self.up_proj = AwqLinear()
+        self.down_proj = AwqLinear()
+
+    def init_parameters(self, hf_layer):
+        self.wq.init_parameters(hf_layer.self_attn.q_proj)
+        self.wk.init_parameters(hf_layer.self_attn.k_proj)
+        self.wv.init_parameters(hf_layer.self_attn.v_proj)
+        self.wo.init_parameters(hf_layer.self_attn.o_proj)
+        self.gate_proj.init_parameters(hf_layer.mlp.gate_proj)
+        self.up_proj.init_parameters(hf_layer.mlp.up_proj)
+        self.down_proj.init_parameters(hf_layer.mlp.down_proj)
+
+        self.input_layernorm_weight: torch.Tensor = hf_layer.input_layernorm.weight.detach()
+        self.input_layernorm_variance_epsilon = hf_layer.input_layernorm.variance_epsilon
+
+        self.post_attention_layernorm_weight: torch.Tensor = hf_layer.post_attention_layernorm.weight.detach()
+        self.post_attention_layernorm_variance_epsilon = hf_layer.post_attention_layernorm.variance_epsilon
+
+    def init_gpu(self, device: str = 'cuda:0'):
         self.input_layernorm_weight = self.input_layernorm_weight.to(device, non_blocking=True)
         self.post_attention_layernorm_weight = self.post_attention_layernorm_weight.to(device, non_blocking=True)
         self.wq = self.wq.to(device, non_blocking=True)
@@ -57,7 +109,7 @@ def init_gpu(self, device:str = 'cuda:0'):
         self.wo = self.wo.to(device, non_blocking=True)
         self.gate_proj = self.gate_proj.to(device, non_blocking=True)
         self.up_proj = self.up_proj.to(device, non_blocking=True)
-        self.down_proj = self.down_proj.to(device, non_blocking=True)
+        self.down_proj = self.down_proj.to(device, non_blocking=True)
 
 
 
@@ -99,6 +151,10 @@ def __init__(self,
         self.v_cache = torch.zeros((max_length, self.num_key_value_heads, self.head_dim), dtype=self.dtype, device=self.device)
         self.chunk_size = 8192
         self.wrt_stream = torch.cuda.Stream()
+
+    def create_layer(self, idx):
+        return LLMLayer(idx)
+
 
     def init_parameters(self):
         hf_model = LlamaForCausalLM.from_pretrained(self.model_name, torch_dtype=self.dtype)
@@ -122,11 +178,11 @@ def init_parameters(self):
         self.sin_cache = self.sin_cache * self.attention_scaling
         self.cos_cache = self.cos_cache.to(self.dtype)
         self.sin_cache = self.sin_cache.to(self.dtype)
-        self.layers :list[LLMLayer] = []
+        self.layers :list[BaseLLMLayer] = []
 
         for idx, hf_layer in enumerate(hf_model.model.layers):
-            layer = LLMLayer(idx)
+            layer = self.create_layer(idx)
             layer.init_parameters(hf_layer=hf_layer)
             layer.init_gpu(self.device)
             self.layers.append(layer)
@@ -134,24 +190,25 @@ def init_parameters(self):
         gc.collect()
         self.num_layers = len(self.layers)
+
+    def apply_linear_layer(self, input, w):
+        assert isinstance(input, torch.Tensor), f"Expected input to be torch.Tensor, got {type(input)}"
+        assert isinstance(w, torch.Tensor), f"Expected w to be torch.Tensor, got {type(w)}"
+        return F.linear(input, w)
+
 
     def pre_attention_compute(
         self,
         hidden_states: torch.Tensor,
-        input_layernorm_variance_epsilon: float,
-        input_layernorm_weight: torch.Tensor,
-        wq:torch.Tensor,
-        wk:torch.Tensor,
-        wv:torch.Tensor,
+        layer: BaseLLMLayer,
         num_heads:int,
         num_key_value_heads:int,
        head_dim:int
     ):
-        hidden_states = layer_norm(hidden_states, input_layernorm_variance_epsilon, input_layernorm_weight)
+        hidden_states = layer_norm(hidden_states, layer.input_layernorm_variance_epsilon, layer.input_layernorm_weight)
         bsz, q_len, _ = hidden_states.size()
-        query_states = F.linear(hidden_states, wq)
-        key_states = F.linear(hidden_states, wk)
-        value_states = F.linear(hidden_states, wv)
+        query_states = self.apply_linear_layer(hidden_states, layer.wq)
+        key_states = self.apply_linear_layer(hidden_states, layer.wk)
+        value_states = self.apply_linear_layer(hidden_states, layer.wv)
         query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
@@ -160,30 +217,22 @@ def post_attention_compute(
         self,
         attn_output: torch.Tensor,
         residual: torch.Tensor,
-        post_attention_layernorm_variance_epsilon: float,
-        post_attention_layernorm_weight: torch.Tensor,
-        wo: torch.Tensor,
-        gate_proj: torch.Tensor,
-        up_proj: torch.Tensor,
-        down_proj: torch.Tensor,
+        layer: BaseLLMLayer
     ):
-
-
-        hidden_states = F.linear(attn_output, wo)
-
+        hidden_states = F.linear(attn_output, layer.wo)
         hidden_states = residual + hidden_states
         residual = hidden_states
-        hidden_states = layer_norm(hidden_states, post_attention_layernorm_variance_epsilon, post_attention_layernorm_weight)
-        up = F.linear(hidden_states, up_proj)
-        gate = F.linear(hidden_states, gate_proj)
+        hidden_states = layer_norm(hidden_states, layer.post_attention_layernorm_variance_epsilon, layer.post_attention_layernorm_weight)
+        up = self.apply_linear_layer(hidden_states, layer.up_proj)
+        gate = self.apply_linear_layer(hidden_states, layer.gate_proj)
         gate = F.silu(gate)
         hidden_states = gate * up
-        hidden_states = F.linear(hidden_states, down_proj)
+        hidden_states = self.apply_linear_layer(hidden_states, layer.down_proj)
         hidden_states = residual + hidden_states
         return hidden_states
     @torch.inference_mode()
     def layer_compute(self,
-            buffer: LLMLayer,
+            buffer: BaseLLMLayer,
             layer_idx :int,
             hidden_states: torch.FloatTensor,
             position_ids: torch.LongTensor):
@@ -191,11 +240,7 @@ def layer_compute(self,
         residual = hidden_states
         query_states, key_states, value_states = self.pre_attention_compute(
             hidden_states,
-            buffer.input_layernorm_variance_epsilon,
-            buffer.input_layernorm_weight,
-            buffer.wq,
-            buffer.wk,
-            buffer.wv,
+            buffer,
             self.num_heads,
             self.num_key_value_heads,
             self.head_dim
@@ -207,21 +252,13 @@ def layer_compute(self,
 
         hidden_states = self.attention_server.decode(query_states, key_states, value_states, layer_idx)
 
-        hidden_states = self.post_attention_compute(
-                        hidden_states, residual,
-                        buffer.post_attention_layernorm_variance_epsilon,
-                        buffer.post_attention_layernorm_weight,
-                        buffer.wo,
-                        buffer.gate_proj,
-                        buffer.up_proj,
-                        buffer.down_proj,
-                        )
+        hidden_states = self.post_attention_compute(hidden_states, residual, buffer)
 
         return hidden_states
 
     @torch.inference_mode()
     def layer_prefill(self,
-            buffer: LLMLayer,
+            buffer: BaseLLMLayer,
             layer_idx :int,
             hidden_states: torch.FloatTensor,
             position_ids: torch.LongTensor,
@@ -233,9 +270,9 @@ def layer_prefill(self,
         for (start, end) in zip(self.chunk_start, self.chunk_end):
             h = layer_norm(hidden_states[:,start:end,:], buffer.input_layernorm_variance_epsilon, buffer.input_layernorm_weight)
             bsz, q_len, _ = h.size()
-            query_states = F.linear(h, buffer.wq)
-            key_states = F.linear(h, buffer.wk)
-            value_states = F.linear(h, buffer.wv)
+            query_states = self.apply_linear_layer(h, buffer.wq)
+            key_states = self.apply_linear_layer(h, buffer.wk)
+            value_states = self.apply_linear_layer(h, buffer.wv)
             query_states = query_states.view(q_len, self.num_heads, self.head_dim)
             key_states = key_states.view(q_len, self.num_key_value_heads, self.head_dim)
             value_states = value_states.view(q_len, self.num_key_value_heads, self.head_dim)
@@ -257,7 +294,7 @@ def layer_prefill(self,
             )
 
             h = h.reshape(bsz, q_len, self.hidden_size)
-            h = F.linear(h, buffer.wo)
+            h = self.apply_linear_layer(h, buffer.wo)
             residual[:,start:end,:].add_(h)
 
         if layer_idx >= 1:
@@ -272,11 +309,11 @@ def layer_prefill(self,
         hidden_states = residual
         for (start, end) in zip(self.chunk_start, self.chunk_end):
             h = layer_norm(hidden_states[:,start:end,:], buffer.post_attention_layernorm_variance_epsilon, buffer.post_attention_layernorm_weight)
-            up = F.linear(h, buffer.up_proj)
-            gate = F.linear(h, buffer.gate_proj)
+            up = self.apply_linear_layer(h, buffer.up_proj)
+            gate = self.apply_linear_layer(h, buffer.gate_proj)
             gate = F.silu(gate)
             h = gate * up
-            h = F.linear(h, buffer.down_proj)
+            h = self.apply_linear_layer(h, buffer.down_proj)
             residual[:,start:end,:].add_(h)
 
         self.attention_server.fill(layer_idx, request_id, self.k_cache, self.v_cache, self.chunk_end[-1])
@@ -363,3 +400,14 @@ def clear(self):
         self.attention_server.clear()
         self.k_cache.zero_()
         self.v_cache.zero_()
+
+
+## This class is used for activation-aware quantization (AWQ)
+class LLMAwq(LLM):
+    def create_layer(self, idx):
+        return LLMAwqLayer(idx)
+
+    def apply_linear_layer(self, input, w):
+        assert isinstance(input, torch.Tensor), f"Expected input to be torch.Tensor, got {type(input)}"
+        assert isinstance(w, AwqLinear), f"Expected w to be AwqLinear, got {type(w)}"
+        return w.apply(input)
diff --git a/quantization/__init__.py b/quantization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/quantization/awq_utils.py b/quantization/awq_utils.py
new file mode 100644
index 0000000..09baab3
--- /dev/null
+++ b/quantization/awq_utils.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+import torch
+from awq.modules.linear import WQLinear_GEMM
+import awq_ext
+class AwqLinear:
+    def __init__(self):
+
+        self.in_features = 0
+        self.out_features = 0
+        self.w_bit = 0
+        self.group_size = 0
+        self.qweight :torch.Tensor = None
+        self.qzeros :torch.Tensor = None
+        self.scales :torch.Tensor = None
+        self.bias :torch.Tensor = None
+
+
+    def init_parameters(self, module: WQLinear_GEMM):
+
+        self.in_features = module.in_features
+        self.out_features = module.out_features
+        self.w_bit = module.w_bit
+        self.group_size = module.group_size
+        self.qweight = module.qweight.detach().pin_memory()
+        self.qzeros = module.qzeros.detach().pin_memory()
+        self.scales = module.scales.detach().pin_memory()
+        if module.bias is not None:
+            self.bias = module.bias.detach()
+        else:
+            self.bias = None
+
+    def empty_like(self, module: WQLinear_GEMM):
+
+        self.in_features = module.in_features
+        self.out_features = module.out_features
+        self.w_bit = module.w_bit
+        self.group_size = module.group_size
+
+        self.qweight = torch.zeros_like(module.qweight.detach())
+        self.qzeros = torch.zeros_like(module.qzeros.detach())
+        self.scales = torch.zeros_like(module.scales.detach())
+        if module.bias is not None:
+            self.bias = torch.zeros_like(module.bias.detach())
+
+    def to(self, device, non_blocking=True):
+
+        self.qweight = self.qweight.to(device, non_blocking=non_blocking)
+        self.qzeros = self.qzeros.to(device, non_blocking=non_blocking)
+        self.scales = self.scales.to(device, non_blocking=non_blocking)
+        if self.bias is not None:
+            self.bias = self.bias.to(device, non_blocking=non_blocking)
+
+
+    def copy(self, module: AwqLinear, non_blocking=True):
+
+        self.qweight.copy_(module.qweight, non_blocking=non_blocking)
+        self.qzeros.copy_(module.qzeros, non_blocking=non_blocking)
+        self.scales.copy_(module.scales, non_blocking=non_blocking)
+        if self.bias is not None:
+            self.bias.copy_(module.bias, non_blocking=non_blocking)
+
+
+    def apply(self, x: torch.Tensor):
+
+        out_shape = x.shape[:-1] + (self.out_features,)
+
+        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
+
+        if FP16_MATMUL_HEURISTIC_CONDITION:
+            out = awq_ext.dequantize_weights_cuda(
+                self.qweight, self.scales, self.qzeros, 0, 0, 0, False
+            )
+            out = torch.matmul(x, out)
+        else:
+            out = awq_ext.gemm_forward_cuda(
+                x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+            )
+
+
+        out = out + self.bias if self.bias is not None else out
+        out = out.reshape(out_shape)
+
+        if len(out.shape) == 2:
+            out = out.unsqueeze(0)
+
+        return out
+
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 8b892cb..67a2dcd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ sentencepiece
 protobuf
 jsonlines
 pytest
-
+autoawq
+autoawq-kernels
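
With the patch applied, the AWQ path in the example scripts is selected by the new --awq flag. A usage sketch follows; the model path is a placeholder, assuming any Llama checkpoint quantized with AutoAWQ's WQLinear_GEMM packing, not a name taken from this repository:

    cd examples
    python generation.py --model <path-to-awq-quantized-llama> --awq
    python bench.py --model <path-to-awq-quantized-llama> --awq

Without --awq, the scripts construct the standard LLM class and behave as before.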