 import pytest
 import torch
 import pathlib
+from functools import partial
 
 import triton
+import triton.language as tl
 from triton._internal_testing import is_xpu
 
 
@@ -74,5 +76,131 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
     kernel = triton.compile(str(temp_file))
 
     kernel[(1, 1, 1)](a, x, b, y)
-    #import pdb; pdb.set_trace()
     assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y)
+
+
+@pytest.mark.parametrize("BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K",
+                         [[256, 256, 32], [256, 64, 32], [64, 256, 32], [64, 128, 32], [64, 64, 32], [32, 32, 32],
+                          [32, 32, 16], [16, 16, 16], [8, 32, 16], [8, 512, 64]])
+@pytest.mark.parametrize("GROUP_SIZE_M", [4, 1])
+@pytest.mark.parametrize("TRANSPOSE_A", [True, False])
+@pytest.mark.parametrize("TRANSPOSE_B", [True, False])
+@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
+@pytest.mark.xfail(
+    not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
+         and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
+    reason="Block loads and/or DPAS not supported on this architecture")
+def test_block_load_dot_product(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, TRANSPOSE_A, TRANSPOSE_B,
+                                device):
+    if GROUP_SIZE_M == 1 and (BLOCK_SIZE_M > 64 or BLOCK_SIZE_N > 64):
+        # skip large block sizes as they will be too slow
+        pytest.xfail("Skipping slow combinations")
+
+    @triton.jit
+    def matmul_kernel_with_block_pointers(
+            # Pointers to matrices
+            a_ptr, b_ptr, c_ptr,
+            # Matrix dimensions
+            M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+            # The stride variables represent how much to increase the ptr by when moving by 1
+            # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+            # by to get the element one row down (A has M rows).
+            stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
+            stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
+            stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
+            # Meta-parameters
+            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+        """Kernel for computing the matmul C = A x B.
+        A has shape (M, K), B has shape (K, N) and C has shape (M, N).
+        """
+        # -----------------------------------------------------------
+        # Map the program id `pid` to the block of C it should compute.
+        # This is done in a grouped ordering to promote L2 data reuse.
+        # See the matrix multiplication tutorial for details.
+        pid = tl.program_id(axis=0)
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        # ----------------------------------------------------------
+        # Create block pointers for the first blocks of A and B.
+        # We will advance these pointers as we move in the K direction and accumulate.
+        # See the `Make a Block Pointer` section of the block-pointer tutorial for details.
+        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                        offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+                                        order=(1, 0))
+        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                        offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+                                        order=(1, 0))
+
+        # -----------------------------------------------------------
+        # Iterate to compute a block of the C matrix.
+        # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+        # of fp32 values for higher accuracy.
+        # `accumulator` is converted back to fp16 when it is stored after the loop.
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            # Load with boundary checks; there is no need to calculate the mask manually.
+            # For better performance, you may remove an axis from the boundary
+            # check if you can guarantee that the access is always in-bounds along
+            # that axis.
+            # See the `Load/Store a Block Pointer` section of the block-pointer tutorial for details.
+            a = tl.load(a_block_ptr, boundary_check=(0, 1))
+            b = tl.load(b_block_ptr, boundary_check=(0, 1))
+            # We accumulate along the K dimension.
+            accumulator += tl.dot(a, b)
+            # Advance the block pointers to the next K block.
+            # See the `Advance a Block Pointer` section of the block-pointer tutorial for details.
+            a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
+            b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        c = accumulator.to(tl.float32)
+
+        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                        offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
+                                        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
+        tl.store(c_block_ptr, c.to(tl.float16), boundary_check=(0, 1))
+
+    def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
+        if transpose_x:
+            K, M = X.shape
+            Xstride0, Xstride1 = X.stride(1), X.stride(0)
+        else:
+            M, K = X.shape
+            Xstride0, Xstride1 = X.stride(0), X.stride(1)
+        if transpose_y:
+            N, _ = Y.shape
+            Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
+        else:
+            _, N = Y.shape
+            Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
+        # Allocates output.
+        Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
+        # 1D launch kernel where each block gets its own program.
+        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+
+        matmul_kernel_with_block_pointers[grid](X, Y, Z, M, N, K, Xstride0, Xstride1, Wstride0, Wstride1, Z.stride(0),
+                                                Z.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
+                                                BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=GROUP_SIZE_M)
+
+        return Z
+
+    M = 512
+    K = 64
+    N = 512
+    dtype = torch.float16
+    torch.manual_seed(0)
+
+    X = torch.randn((M, K) if not TRANSPOSE_A else (K, M), device=device, dtype=dtype, requires_grad=False)
+    Y = torch.randn((K, N) if not TRANSPOSE_B else (N, K), device=device, dtype=dtype, requires_grad=False)
+
+    fn_tor = partial(torch.mm, X if not TRANSPOSE_A else X.T, Y if not TRANSPOSE_B else Y.T)
+    fn_tri = partial(triton_mm, X, Y, transpose_x=TRANSPOSE_A, transpose_y=TRANSPOSE_B)
+
+    result_tor = fn_tor()
+    result_tri = fn_tri()
+    torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
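
Editor's note: the grouped (swizzled) mapping from a linear program id to a (pid_m, pid_n) block of C in the kernel above is the part that is easiest to misread. The following minimal sketch repeats the same index arithmetic in plain Python so it can be inspected on the host; it is not part of the diff, and the helper name grouped_pid_map and the 4x4 example grid are illustrative only.

def grouped_pid_map(pid, num_pid_m, num_pid_n, group_size_m):
    # Mirrors the pid -> (pid_m, pid_n) computation in matmul_kernel_with_block_pointers:
    # program ids are grouped so that group_size_m row-blocks of C are visited before
    # moving to the next column of blocks, which promotes reuse of A/B tiles in cache.
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    group_size_m = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# For a 4x4 grid of C blocks with group_size_m=2, consecutive pids map to
# (0,0), (1,0), (0,1), (1,1), (0,2), ... i.e. two row-blocks per column before advancing.
for pid in range(8):
    print(pid, grouped_pid_map(pid, num_pid_m=4, num_pid_n=4, group_size_m=2))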