Commit e279ab7
Add block ptr test for dot product with transpose (#4510)
Adds end-to-end dot product on block ptr testing to the block load unit test (maybe it should be renamed `test_block_ptr.py`?). Adds additional shapes, A transpose, and B transpose.

The cold runtime (no cache) is approximately 1 minute on PVC 1100 in my environment. I picked the block shapes somewhat randomly, trying to balance breadth and runtime. This somewhat duplicates tutorial 10 but allows us to run many more combinations in less time.

I added this because #4463 is passing CI but has a few bugs that are not being caught by existing unit tests, including the tutorials.

Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
1 parent 1cdfaf0 commit e279ab7
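The new `triton_mm` helper in the diff below handles `TRANSPOSE_A`/`TRANSPOSE_B` by swapping the input strides rather than materializing a transposed copy, so the kernel's block pointers simply walk the same buffer in a different order. A minimal standalone sketch of that idea (plain PyTorch, not part of this commit; the CPU tensor and the name `A_logical` are illustrative only):

import torch

# The test stores A transposed as a (K, M) tensor when TRANSPOSE_A is True.
K, M = 64, 512
X = torch.randn((K, M), dtype=torch.float16)

# Swapping the strides reinterprets the (K, M) buffer as a logical (M, K) operand,
# mirroring what triton_mm does before building its block pointers.
stride_am, stride_ak = X.stride(1), X.stride(0)
A_logical = torch.as_strided(X, size=(M, K), stride=(stride_am, stride_ak))

# The stride-swapped view is elementwise equal to an explicit transpose.
assert torch.equal(A_logical, X.T)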

File tree

1 file changed: +129, -1 lines


python/test/unit/intel/test_block_load.py

Lines changed: 129 additions & 1 deletion
@@ -1,8 +1,10 @@
 import pytest
 import torch
 import pathlib
+from functools import partial
 
 import triton
+import triton.language as tl
 from triton._internal_testing import is_xpu
 
 
@@ -74,5 +76,131 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
     kernel = triton.compile(str(temp_file))
 
     kernel[(1, 1, 1)](a, x, b, y)
-    #import pdb; pdb.set_trace()
     assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y)
+
+
+@pytest.mark.parametrize("BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K",
+                         [[256, 256, 32], [256, 64, 32], [64, 256, 32], [64, 128, 32], [64, 64, 32], [32, 32, 32],
+                          [32, 32, 16], [16, 16, 16], [8, 32, 16], [8, 512, 64]])
+@pytest.mark.parametrize("GROUP_SIZE_M", [4, 1])
+@pytest.mark.parametrize("TRANSPOSE_A", [True, False])
+@pytest.mark.parametrize("TRANSPOSE_B", [True, False])
+@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
+@pytest.mark.xfail(
+    not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
+         and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
+    reason="Block loads and/or DPAS not supported on this architecture")
+def test_block_load_dot_product(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, TRANSPOSE_A, TRANSPOSE_B,
+                                device):
+    if GROUP_SIZE_M == 1 and (BLOCK_SIZE_M > 64 or BLOCK_SIZE_N > 64):
+        # skip large block sizes as they will be too slow
+        pytest.xfail("Skipping slow combinations")
+
+    @triton.jit
+    def matmul_kernel_with_block_pointers(
+            # Pointers to matrices
+            a_ptr, b_ptr, c_ptr,
+            # Matrix dimensions
+            M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
+            # The stride variables represent how much to increase the ptr by when moving by 1
+            # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
+            # by to get the element one row down (A has M rows).
+            stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
+            stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
+            stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
+            # Meta-parameters
+            BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
+        """Kernel for computing the matmul C = A x B.
+        A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+        """
+        # -----------------------------------------------------------
+        # Map program ids `pid` to the block of C it should compute.
+        # This is done in a grouped ordering to promote L2 data reuse.
+        # See the matrix multiplication tutorial for details.
+        pid = tl.program_id(axis=0)
+        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+        num_pid_in_group = GROUP_SIZE_M * num_pid_n
+        group_id = pid // num_pid_in_group
+        first_pid_m = group_id * GROUP_SIZE_M
+        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+        pid_m = first_pid_m + (pid % group_size_m)
+        pid_n = (pid % num_pid_in_group) // group_size_m
+
+        # ----------------------------------------------------------
+        # Create block pointers for the first blocks of A and B.
+        # We will advance this pointer as we move in the K direction and accumulate.
+        # See above `Make a Block Pointer` section for details.
+        a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
+                                        offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+                                        order=(1, 0))
+        b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
+                                        offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+                                        order=(1, 0))
+
+        # -----------------------------------------------------------
+        # Iterate to compute a block of the C matrix.
+        # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+        # of fp32 values for higher accuracy.
+        # `accumulator` will be converted back to fp16 after the loop.
+        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+        for k in range(0, K, BLOCK_SIZE_K):
+            # Load with boundary checks, no need to calculate the mask manually.
+            # For better performance, you may remove some axis from the boundary
+            # check, if you can guarantee that the access is always in-bound in
+            # that axis.
+            # See above `Load/Store a Block Pointer` section for details.
+            a = tl.load(a_block_ptr, boundary_check=(0, 1))
+            b = tl.load(b_block_ptr, boundary_check=(0, 1))
+            # We accumulate along the K dimension.
+            accumulator += tl.dot(a, b)
+            # Advance the block pointer to the next K block.
+            # See above `Advance a Block Pointer` section for details.
+            a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
+            b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
+        c = accumulator.to(tl.float32)
+
+        c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
+                                        offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
+                                        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
+        tl.store(c_block_ptr, c.to(tl.float16), boundary_check=(0, 1))
+
+    def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
+        if transpose_x:
+            K, M = X.shape
+            Xstride0, Xstride1 = X.stride(1), X.stride(0)
+        else:
+            M, K = X.shape
+            Xstride0, Xstride1 = X.stride(0), X.stride(1)
+        if transpose_y:
+            N, _ = Y.shape
+            Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
+        else:
+            _, N = Y.shape
+            Wstride0, Wstride1 = Y.stride(0), Y.stride(1)
+        # Allocates output.
+        Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
+        # 1D launch kernel where each block gets its own program.
+        grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
+
+        matmul_kernel_with_block_pointers[grid](X, Y, Z, M, N, K, Xstride0, Xstride1, Wstride0, Wstride1, Z.stride(0),
+                                                Z.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
+                                                BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=GROUP_SIZE_M)
+
+        return Z
+
+    M = 512
+    K = 64
+    N = 512
+    dtype = torch.float16
+    torch.manual_seed(0)
+
+    X = torch.randn((M, K) if not TRANSPOSE_A else (K, M), device=device, dtype=dtype, requires_grad=False)
+    Y = torch.randn((K, N) if not TRANSPOSE_B else (N, K), device=device, dtype=dtype, requires_grad=False)
+
+    fn_tor = partial(torch.mm, X if not TRANSPOSE_A else X.T, Y if not TRANSPOSE_B else Y.T)
+    fn_tri = partial(triton_mm, X, Y, transpose_x=TRANSPOSE_A, transpose_y=TRANSPOSE_B)
+
+    result_tor = fn_tor()
+    result_tri = fn_tri()
+    torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
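The grouped program-id mapping at the top of `matmul_kernel_with_block_pointers` follows the matrix multiplication tutorial: consecutive program ids stay within a `GROUP_SIZE_M`-tall band of C tiles so that the corresponding B tiles can be reused from cache. A small plain-Python sketch of that mapping (the helper `grouped_pid` is illustrative only; the default sizes are one representative point from the test's parameter grid, not part of the committed file):

def grouped_pid(pid, M=512, N=512, BLOCK_SIZE_M=64, BLOCK_SIZE_N=64, GROUP_SIZE_M=4):
    # Same arithmetic as the kernel, with tl.cdiv spelled as ceiling division.
    num_pid_m = -(-M // BLOCK_SIZE_M)
    num_pid_n = -(-N // BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# The first eight programs cover a 4-tall, 2-wide band of C blocks:
# (0, 0) (1, 0) (2, 0) (3, 0) (0, 1) (1, 1) (2, 1) (3, 1)
print([grouped_pid(p) for p in range(8)])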

0 commit comments