
Commit e5697d1

[Kernel] [Triton] [AMD] Adding Triton implementations awq_dequantize and awq_gemm to support AWQ (#7386)
1 parent b98cc28 commit e5697d1

File tree

5 files changed: +493 -1 lines changed

tests/kernels/test_awq_triton.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
"""Tests for the AWQ Triton kernel.

Run `pytest tests/kernels/test_awq_triton.py`.
"""
import pytest
import torch

from vllm.model_executor.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)

device = "cuda"


def reverse_awq_order(t: torch.Tensor):
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    reverse_order_tensor = torch.arange(
        t.shape[-1],
        dtype=torch.int32,
        device=t.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    t = t[:, reverse_order_tensor] & 0xF
    return t


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
                         qzeros: torch.Tensor,
                         group_size: int) -> torch.Tensor:

    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    iweights = torch.bitwise_right_shift(qweight[:, :, None],
                                         shifts[None, None, :]).to(torch.int8)

    iweights = iweights.view(iweights.shape[0], -1)

    zeros = torch.bitwise_right_shift(qzeros[:, :, None],
                                      shifts[None, None, :]).to(torch.int8)
    zeros = zeros.view(qzeros.shape[0], -1)
    zeros = reverse_awq_order(zeros)

    iweights = reverse_awq_order(iweights)

    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):

    if group_size == -1:
        group_size = qweight_rows

    qweight_dtype = torch.int32
    scales_rows = qweight_rows // group_size
    scales_cols = qweight_cols * 8
    scales_dtype = torch.float16
    zeros_rows = scales_rows
    zeros_cols = qweight_cols
    zeros_dtype = torch.int32

    torch.manual_seed(0)

    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            dtype=qweight_dtype,
                            device=device)
    scales = torch.rand(scales_rows,
                        scales_cols,
                        dtype=scales_dtype,
                        device=device)
    zeros = torch.randint(0,
                          torch.iinfo(torch.int32).max,
                          (zeros_rows, zeros_cols),
                          dtype=zeros_dtype,
                          device=device)

    iweights_triton = awq_dequantize_triton(qweight, scales, zeros)

    assert (not torch.any(torch.isinf(iweights_triton))
            and not torch.any(torch.isnan(iweights_triton)))

    iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)

    torch.testing.assert_close(iweights_triton, iweights_torch)


# input   - [N, K]
# qweight - [K, M // 8]
# qzeros  - [K // G, M // 8]
# scales  - [K // G, M]
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("M", [16, 24, 32])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):

    if group_size == -1:
        group_size = K

    split_k_iters = splitK

    input_rows = N
    input_cols = K
    input_dtype = torch.float32
    qweight_rows = input_cols
    qweight_cols = M // 8
    scales_rows = qweight_rows // group_size
    scales_cols = M
    scales_dtype = torch.float32
    qzeros_rows = scales_rows
    qzeros_cols = qweight_cols

    torch.manual_seed(0)

    input = torch.rand((input_rows, input_cols),
                       dtype=input_dtype,
                       device=device)
    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            device=device)
    qzeros = torch.randint(0,
                           torch.iinfo(torch.int32).max,
                           (qzeros_rows, qzeros_cols),
                           device=device)
    scales = torch.rand((scales_rows, scales_cols),
                        dtype=scales_dtype,
                        device=device)

    output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
                                    split_k_iters)

    assert (not torch.any(torch.isinf(output_triton))
            and not torch.any(torch.isnan(output_triton)))

    dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)

    output_torch = torch.matmul(input, dequantized_weights)

    assert (not torch.any(torch.isinf(output_torch))
            and not torch.any(torch.isnan(output_torch)))

    torch.testing.assert_close(output_triton.cpu(),
                               output_torch.cpu(),
                               atol=1e-1,
                               rtol=1e-1)
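
Each int32 in qweight and qzeros packs eight 4-bit values; reverse_awq_order applies the permutation AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] to put the unpacked columns back in logical order. A minimal illustrative sketch (not part of the commit) of the shift-and-mask unpacking and of how that permutation reshuffles the columns:

    import torch

    packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # nibbles 0..7, low to high
    shifts = torch.arange(0, 32, 4)  # 4-bit steps, as in awq_dequantize_torch above
    nibbles = (packed[:, :, None] >> shifts[None, None, :]) & 0xF
    print(nibbles.view(-1).tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
    # Applying the AWQ_REVERSE_ORDER permutation to see how columns get reshuffled:
    print(nibbles.view(-1)[[0, 4, 1, 5, 2, 6, 3, 7]].tolist())  # [0, 4, 1, 5, 2, 6, 3, 7]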

vllm/_custom_ops.py

Lines changed: 9 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm._core_ext import ScalarType
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -177,12 +178,20 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                    zeros: torch.Tensor, split_k_iters: int, thx: int,
                    thy: int) -> torch.Tensor:
+    if envs.VLLM_USE_TRITON_AWQ:
+        from vllm.model_executor.layers.quantization.awq_triton import (
+            awq_dequantize_triton)
+        return awq_dequantize_triton(qweight, scales, zeros)
     return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                        thx, thy)
 
 
 def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
              scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
+    if envs.VLLM_USE_TRITON_AWQ:
+        from vllm.model_executor.layers.quantization.awq_triton import (
+            awq_gemm_triton)
+        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
     return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
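
With these branches in place, the existing ops-level entry points dispatch to the Triton kernels whenever VLLM_USE_TRITON_AWQ is enabled and fall back to the torch.ops._C custom ops otherwise. A rough usage sketch (the shapes follow the comments in the test file; the tensors are random placeholders, not a real checkpoint):

    import os
    os.environ["VLLM_USE_TRITON_AWQ"] = "1"  # route awq_dequantize/awq_gemm to Triton

    import torch
    from vllm import _custom_ops as ops

    qweight = torch.randint(0, torch.iinfo(torch.int32).max, (128, 16),
                            dtype=torch.int32, device="cuda")        # [R, C // 8]
    scales = torch.rand(1, 128, dtype=torch.float16, device="cuda")   # [R // G, C], G = 128
    zeros = torch.randint(0, torch.iinfo(torch.int32).max, (1, 16),
                          dtype=torch.int32, device="cuda")           # [R // G, C // 8]

    # With the env var set, this returns awq_dequantize_triton(qweight, scales, zeros);
    # the trailing split_k_iters/thx/thy arguments are only used by the CUDA kernel path.
    weights = ops.awq_dequantize(qweight, scales, zeros, 0, 0, 0)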

vllm/config.py

Lines changed: 7 additions & 1 deletion
@@ -267,7 +267,7 @@ def _parse_quant_hf_config(self):
 
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
+        rocm_supported_quantization = ["awq", "gptq", "squeezellm", "fp8"]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
@@ -322,6 +322,12 @@ def _verify_quantization(self) -> None:
                 "%s quantization is not fully "
                 "optimized yet. The speed can be slower than "
                 "non-quantized models.", self.quantization)
+        if (self.quantization == "awq" and is_hip()
+                and not envs.VLLM_USE_TRITON_AWQ):
+            logger.warning(
+                "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
+                " is not set, enabling VLLM_USE_TRITON_AWQ.")
+            envs.VLLM_USE_TRITON_AWQ = True
 
     def _verify_cuda_graph(self) -> None:
         if self.max_seq_len_to_capture is None:

vllm/envs.py

Lines changed: 4 additions & 0 deletions
@@ -400,6 +400,10 @@ def get_default_config_root():
     "VLLM_TORCH_PROFILER_DIR":
     lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os
              .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))),
+
+    # If set, vLLM will use Triton implementations of AWQ.
+    "VLLM_USE_TRITON_AWQ":
+    lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
 }
 
 # end-env-vars-definition
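
The new entry follows the same convention as the other boolean flags in this table: unset or "0" leaves the Triton path disabled, while a non-zero integer string enables it. A small standalone sketch (mirroring the lambda above, outside of vLLM):

    import os

    def use_triton_awq() -> bool:
        # Mirrors: lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0")))
        return bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0")))

    os.environ.pop("VLLM_USE_TRITON_AWQ", None)
    print(use_triton_awq())  # False (defaults to "0")

    os.environ["VLLM_USE_TRITON_AWQ"] = "1"
    print(use_triton_awq())  # True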
