|
"""Tests for the AWQ Triton kernel.

Run `pytest tests/kernels/test_awq_triton.py`.
"""
| 5 | +import pytest |
| 6 | +import torch |
| 7 | + |
| 8 | +from vllm.model_executor.layers.quantization.awq_triton import ( |
| 9 | + AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) |
| 10 | + |
# Device on which the tests in this module allocate their tensors.
device = "cuda"
| 12 | + |
| 13 | + |
def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    """Reorder the last dimension of ``t`` in blocks of 8 and keep the low 4 bits.

    Within every consecutive block of ``32 // bits`` (= 8) entries along the
    last dimension, the entries are gathered in the order
    ``[0, 4, 1, 5, 2, 6, 3, 7]``. Each resulting element is then masked
    with ``0xF``.

    Args:
        t: Integer tensor whose last dimension is a multiple of 8. Any
            number of leading dimensions is accepted.

    Returns:
        A tensor with the same shape as ``t`` holding the reordered and
        masked values.

    Raises:
        RuntimeError: If the size of the last dimension is not divisible
            by 8 (the index cannot be viewed as rows of 8).
    """
    bits = 4
    pack_factor = 32 // bits
    awq_reverse_order = [0, 4, 1, 5, 2, 6, 3, 7]

    # One gather index per position of the last dimension, permuted
    # block-wise. int64 is the canonical dtype for index tensors.
    column_index = torch.arange(
        t.shape[-1],
        dtype=torch.long,
        device=t.device,
    )
    column_index = column_index.view(-1, pack_factor)[:, awq_reverse_order]
    column_index = column_index.reshape(-1)

    # The index was derived from the last dimension, so apply it there
    # regardless of how many leading dimensions ``t`` has.
    return t[..., column_index] & 0xF
| 28 | + |
| 29 | + |
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
                         qzeros: torch.Tensor,
                         group_size: int) -> torch.Tensor:
    """Dequantize packed 4-bit weights using plain PyTorch ops.

    Expected layouts (R rows, C unpacked columns, G rows per group):
        qweight - [R     , C // 8], int32
        scales  - [R // G, C     ], float16
        qzeros  - [R // G, C // 8], int32

    Args:
        qweight: Packed weights, eight 4-bit values per element.
        scales: Per-group scale factors.
        qzeros: Packed per-group zero points, eight 4-bit values per element.
        group_size: Number of rows sharing one scale/zero row; ``-1`` is
            replaced by the number of rows of ``qweight``.

    Returns:
        ``(unpacked_weights - unpacked_zeros) * scales`` with the scales and
        zeros repeated ``group_size`` times along dim 0.
    """
    bits = 4
    nibble_mask = (2**bits) - 1

    if group_size == -1:
        group_size = qweight.shape[0]

    # Bit offsets 0, 4, ..., 28 of the eight values packed in each element.
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # Shift every packed element by each offset, flatten the new axis
        # into the column axis, then reorder and mask to 4 bits.
        nibbles = torch.bitwise_right_shift(packed[:, :, None],
                                            shifts[None, None, :])
        nibbles = nibbles.to(torch.int8).view(packed.shape[0], -1)
        return torch.bitwise_and(reverse_awq_order(nibbles), nibble_mask)

    iweights = _unpack(qweight)
    zeros = _unpack(qzeros)

    # Expand the per-group rows so they line up with the weight rows.
    expanded_scales = scales.repeat_interleave(group_size, dim=0)
    expanded_zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - expanded_zeros) * expanded_scales
| 61 | + |
| 62 | + |
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):
    """Compare ``awq_dequantize_triton`` with the PyTorch reference.

    Tensor layouts used (R = qweight_rows, C = qweight_cols * 8,
    G = group size):
        qweight - [R     , C // 8], int32
        scales  - [R // G, C     ], float16
        zeros   - [R // G, C // 8], int32
    """
    # A group size of -1 is treated as one group covering all rows.
    effective_group_size = qweight_rows if group_size == -1 else group_size
    num_groups = qweight_rows // effective_group_size
    int32_max = torch.iinfo(torch.int32).max

    torch.manual_seed(0)

    # The order of the random draws below is kept fixed so the seeded
    # inputs stay the same.
    qweight = torch.randint(0,
                            int32_max, (qweight_rows, qweight_cols),
                            dtype=torch.int32,
                            device=device)
    scales = torch.rand(num_groups,
                        qweight_cols * 8,
                        dtype=torch.float16,
                        device=device)
    zeros = torch.randint(0,
                          int32_max, (num_groups, qweight_cols),
                          dtype=torch.int32,
                          device=device)

    triton_result = awq_dequantize_triton(qweight, scales, zeros)

    # The kernel output must not contain inf or NaN.
    assert not torch.any(torch.isinf(triton_result))
    assert not torch.any(torch.isnan(triton_result))

    torch_result = awq_dequantize_torch(qweight, scales, zeros,
                                        effective_group_size)

    torch.testing.assert_close(triton_result, torch_result)
| 107 | + |
| 108 | + |
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("M", [16, 24, 32])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):
    """Compare ``awq_gemm_triton`` with matmul over Triton-dequantized weights.

    Tensor layouts used (G = group size):
        input   - [N     , K     ]
        qweight - [K     , M // 8]
        qzeros  - [K // G, M // 8]
        scales  - [K // G, M     ]
    """
    # A group size of -1 is treated as one group covering all K rows.
    effective_group_size = K if group_size == -1 else group_size
    num_groups = K // effective_group_size
    packed_cols = M // 8
    int32_max = torch.iinfo(torch.int32).max

    torch.manual_seed(0)

    # The order of the random draws below is kept fixed so the seeded
    # inputs stay the same.
    activations = torch.rand((N, K), dtype=torch.float32, device=device)
    # NOTE(review): no dtype is passed to randint here, so qweight and qzeros
    # use torch.randint's default integer dtype rather than the explicit
    # int32 used in test_dequantize; confirm this is intended.
    qweight = torch.randint(0, int32_max, (K, packed_cols), device=device)
    qzeros = torch.randint(0,
                           int32_max, (num_groups, packed_cols),
                           device=device)
    scales = torch.rand((num_groups, M), dtype=torch.float32, device=device)

    output_triton = awq_gemm_triton(activations, qweight, scales, qzeros,
                                    splitK)

    assert not torch.any(torch.isinf(output_triton))
    assert not torch.any(torch.isnan(output_triton))

    # Reference result: dequantize with the Triton kernel, then a dense matmul.
    dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
    output_torch = torch.matmul(activations, dequantized_weights)

    assert not torch.any(torch.isinf(output_torch))
    assert not torch.any(torch.isnan(output_torch))

    torch.testing.assert_close(output_triton.cpu(),
                               output_torch.cpu(),
                               atol=1e-1,
                               rtol=1e-1)
0 commit comments