
Add block ptr test for dot product with transpose #4510

Merged: 4 commits, Jun 17, 2025
130 changes: 129 additions & 1 deletion python/test/unit/intel/test_block_load.py
@@ -1,8 +1,10 @@
import pytest
import torch
import pathlib
from functools import partial

import triton
import triton.language as tl
from triton._internal_testing import is_xpu


@@ -74,5 +76,131 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
kernel = triton.compile(str(temp_file))

kernel[(1, 1, 1)](a, x, b, y)
assert torch.equal(a, x) and torch.equal(b.T if transpose else b, y)


@pytest.mark.parametrize("BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K",
[[256, 256, 32], [256, 64, 32], [64, 256, 32], [64, 128, 32], [64, 64, 32], [32, 32, 32],
[32, 32, 16], [16, 16, 16], [8, 32, 16], [8, 512, 64]])
@pytest.mark.parametrize("GROUP_SIZE_M", [4, 1])
@pytest.mark.parametrize("TRANSPOSE_A", [True, False])
@pytest.mark.parametrize("TRANSPOSE_B", [True, False])
@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
@pytest.mark.xfail(
not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
reason="Block loads and/or DPAS not supported on this architecture")
def test_block_load_dot_product(BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, TRANSPOSE_A, TRANSPOSE_B,
device):
    if GROUP_SIZE_M == 1 and (BLOCK_SIZE_M > 64 or BLOCK_SIZE_N > 64):
        # Skip large block sizes with GROUP_SIZE_M == 1; they are too slow to be worth running.
        pytest.skip("Skipping slow combinations")

@triton.jit
def matmul_kernel_with_block_pointers(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am: tl.constexpr, stride_ak: tl.constexpr, #
stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
stride_cm: tl.constexpr, stride_cn: tl.constexpr, #
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See the matrix multiplication tutorial for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
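        # Worked example (with values this test actually uses): for M = N = 512
        # and BLOCK_SIZE_M = BLOCK_SIZE_N = 64, num_pid_m = num_pid_n = 8 and,
        # with GROUP_SIZE_M = 4, num_pid_in_group = 32. Pids 0-3 then map to
        # (pid_m, pid_n) = (0, 0), (1, 0), (2, 0), (3, 0) and pids 4-7 to
        # (0, 1), (1, 1), (2, 1), (3, 1): each group sweeps all block-columns
        # of B while touching only GROUP_SIZE_M block-rows of A, which is the
        # L2 reuse described above.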

# ----------------------------------------------------------
# Create block pointers for the first blocks of A and B.
        # We will advance these pointers as we move in the K direction and accumulate.
# See above `Make a Block Pointer` section for details.
a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
offsets=(pid_m * BLOCK_SIZE_M, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
offsets=(0, pid_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
order=(1, 0))
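        # Note: a transposed A or B needs no special handling here. The caller
        # (triton_mm below) encodes the transpose by swapping the stride
        # arguments, so the same block shapes and order cover both layouts.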

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
        # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        # of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# Load with boundary checks, no need to calculate the mask manually.
        # For better performance, you may remove some axes from the boundary
        # check if you can guarantee that the accesses are always in bounds
        # along those axes.
# See above `Load/Store a Block Pointer` section for details.
a = tl.load(a_block_ptr, boundary_check=(0, 1))
b = tl.load(b_block_ptr, boundary_check=(0, 1))
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the block pointer to the next K block.
# See above `Advance a Block Pointer` section for details.
a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
        c = accumulator.to(tl.float16)

c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0))
        tl.store(c_block_ptr, c, boundary_check=(0, 1))

    def triton_mm(X, Y, transpose_x=False, transpose_y=False):
if transpose_x:
K, M = X.shape
Xstride0, Xstride1 = X.stride(1), X.stride(0)
else:
M, K = X.shape
Xstride0, Xstride1 = X.stride(0), X.stride(1)
        if transpose_y:
            N, _ = Y.shape
            Ystride0, Ystride1 = Y.stride(1), Y.stride(0)
        else:
            _, N = Y.shape
            Ystride0, Ystride1 = Y.stride(0), Y.stride(1)
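        # Swapping the strides is all a transpose requires: a contiguous
        # (K, M) tensor read as an (M, K) matrix with strides
        # (stride(1), stride(0)) is exactly its transpose, with no data copied.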
# Allocates output.
Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
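        # For example, M = N = 512 with 64x64 output blocks launches
        # cdiv(512, 64) * cdiv(512, 64) = 64 programs.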

        matmul_kernel_with_block_pointers[grid](X, Y, Z, M, N, K, Xstride0, Xstride1, Ystride0, Ystride1, Z.stride(0),
                                                Z.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,
                                                BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=GROUP_SIZE_M)

return Z

M = 512
K = 64
N = 512
dtype = torch.float16
torch.manual_seed(0)

X = torch.randn((M, K) if not TRANSPOSE_A else (K, M), device=device, dtype=dtype, requires_grad=False)
Y = torch.randn((K, N) if not TRANSPOSE_B else (N, K), device=device, dtype=dtype, requires_grad=False)

fn_tor = partial(torch.mm, X if not TRANSPOSE_A else X.T, Y if not TRANSPOSE_B else Y.T)
fn_tri = partial(triton_mm, X, Y, transpose_x=TRANSPOSE_A, transpose_y=TRANSPOSE_B)

result_tor = fn_tor()
result_tri = fn_tri()
torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
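
For intuition, the transpose handling above relies on a 2-D transpose in PyTorch being a stride swap rather than a copy. A minimal standalone sketch (illustrative, not part of the diff) checking that invariant:

import torch

X = torch.randn(4, 8)
# X.T is a view over the same storage with the stride tuple reversed.
assert X.T.stride() == (X.stride(1), X.stride(0))
assert X.T.data_ptr() == X.data_ptr()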