
Commit 3a8ce9b

Unskip some previously failing tests (#162)
1 parent ce3b6c7 commit 3a8ce9b

File tree

2 files changed: 68 additions & 12 deletions


test/test_loops.py

Lines changed: 32 additions & 2 deletions
@@ -370,7 +370,6 @@ def _fn_make_precompiler(x: torch.Tensor):
     return make_precompiler(_fn_kernel)(x, out, out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), a, c, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )

-    @unittest.skip("TODO(jansel): fix this")
     def test_loop_arg_block(self):
         @helion.kernel(config={"block_sizes": [], "indexing": "block_ptr"})
         def fn(x: torch.Tensor, block_size: int) -> torch.Tensor:
@@ -386,7 +385,38 @@ def fn(x: torch.Tensor, block_size: int) -> torch.Tensor:
             args,
         )
         torch.testing.assert_close(result, torch.sin(args[0]))
-        self.assertExpectedInline(code, """""")
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_helpers import math as tl_math
+
+@triton.jit
+def _fn_kernel(x, out, out_size_0, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    load = tl.load(tl.make_block_ptr(x, [x_size_0], [x_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), boundary_check=[0], padding_option='zero')
+    v_0 = tl_math.sin(load)
+    tl.store(tl.make_block_ptr(out, [out_size_0], [out_stride_0], [offset_0], [_BLOCK_SIZE_0], [0]), v_0, boundary_check=[0])
+
+def fn(x: torch.Tensor, block_size: int):
+    out = torch.empty_like(x)
+    a, = x.shape
+    _BLOCK_SIZE_0 = block_size
+    _fn_kernel[triton.cdiv(a, _BLOCK_SIZE_0),](x, out, out.size(0), x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x: torch.Tensor, block_size: int):
+    out = torch.empty_like(x)
+    a, = x.shape
+    _BLOCK_SIZE_0 = block_size
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )

     def test_three_level_matmul(self):
         @helion.kernel(static_shapes=True)
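
For context, here is a minimal sketch (not part of the commit) of how the generated fn wrapper in the expected output above could be exercised; the module name is hypothetical, and the check mirrors the test's own assert_close against torch.sin.

import torch

from generated_sin import fn  # hypothetical module holding the generated wrapper above

x = torch.randn(1024, device="cuda")
# block_size is a plain runtime argument here: the wrapper sets _BLOCK_SIZE_0 = block_size
# and launches triton.cdiv(a, _BLOCK_SIZE_0) programs over the 1-D input.
result = fn(x, block_size=64)
torch.testing.assert_close(result, torch.sin(x))  # same check the test performs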

test/test_matmul.py

Lines changed: 36 additions & 10 deletions
@@ -537,7 +537,6 @@ def _matmul_static_shapes_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     return make_precompiler(_matmul_static_shapes_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )

-    @unittest.skip("need to debug correctness issue")
     def test_matmul_static_shapes2(self):
         args = (
             torch.randn([128, 127], device=DEVICE, dtype=torch.float32),
@@ -553,6 +552,8 @@ def test_matmul_static_shapes2(self):
         self.assertExpectedInline(
             code,
             """\
+from __future__ import annotations
+
 import torch
 import triton
 import triton.language as tl
@@ -568,17 +569,18 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_
     pid_0 = first_pid_m + tl.program_id(0) % num_pid_in_group % group_size_m
     pid_1 = tl.program_id(0) % num_pid_in_group // group_size_m
     offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
-    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
     for offset_2 in range(0, 127, _BLOCK_SIZE_2):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
         mask_2 = indices_2 < 127
+        acc_copy = acc
         load = tl.load(x + (indices_0[:, None] * 127 + indices_2[None, :] * 1), mask_2[None, :], other=0)
         load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), mask_2[:, None], other=0)
         mm = tl.dot(load, load_1, input_precision='tf32')
-        acc = acc + mm
+        acc = acc_copy + mm
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)

 def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
@@ -590,10 +592,20 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
     _BLOCK_SIZE_1 = 16
     _BLOCK_SIZE_2 = 16
     _matmul_static_shapes_kernel[triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
-    return out""",
+    return out
+
+def _matmul_static_shapes_make_precompiler(x: torch.Tensor, y: torch.Tensor):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_matmul_static_shapes_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )

-    @unittest.skip("need to debug correctness issue")
     def test_matmul_static_shapes3(self):
         args = (
             torch.randn([127, 128], device=DEVICE, dtype=torch.float32),
@@ -609,6 +621,8 @@ def test_matmul_static_shapes3(self):
         self.assertExpectedInline(
             code,
             """\
+from __future__ import annotations
+
 import torch
 import triton
 import triton.language as tl
@@ -624,18 +638,19 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_
     pid_0 = first_pid_m + tl.program_id(0) % num_pid_in_group % group_size_m
     pid_1 = tl.program_id(0) % num_pid_in_group // group_size_m
     offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < 127
     offset_1 = pid_1 * _BLOCK_SIZE_1
-    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < 127
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
     for offset_2 in range(0, 128, _BLOCK_SIZE_2):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
         load = tl.load(x + (indices_0[:, None] * 128 + indices_2[None, :] * 1), mask_0[:, None], other=0)
         load_1 = tl.load(y + (indices_2[:, None] * 127 + indices_1[None, :] * 1), mask_1[None, :], other=0)
         mm = tl.dot(load, load_1, input_precision='tf32')
-        acc = acc + mm
+        acc = acc_copy + mm
     tl.store(out + (indices_0[:, None] * 127 + indices_1[None, :] * 1), acc, mask_0[:, None] & mask_1[None, :])

 def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
@@ -647,7 +662,18 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
     _BLOCK_SIZE_1 = 16
     _BLOCK_SIZE_2 = 16
     _matmul_static_shapes_kernel[triton.cdiv(127, _BLOCK_SIZE_0) * triton.cdiv(127, _BLOCK_SIZE_1),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
-    return out""",
+    return out
+
+def _matmul_static_shapes_make_precompiler(x: torch.Tensor, y: torch.Tensor):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_matmul_static_shapes_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )

     def test_matmul_split_k(self):
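
A detail shared by both matmul hunks above: the expected output now casts the whole index expression, (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32), rather than casting only the tl.arange term, and it snapshots the loop-carried accumulator (acc_copy = acc) before adding the tl.dot result. As a rough PyTorch analogue (an illustration only, not Triton semantics), moving the cast outside the sum changes the dtype of the final index vector:

import torch

offset = torch.tensor(512, dtype=torch.int64)  # stand-in for a 64-bit block offset

# Cast applied to arange only: adding the int64 offset promotes the sum back to int64.
old = offset + torch.arange(0, 16).to(torch.int32)
# Cast applied to the whole expression: the final index vector stays int32.
new = (offset + torch.arange(0, 16)).to(torch.int32)

print(old.dtype, new.dtype)  # torch.int64 torch.int32
torch.testing.assert_close(old.to(torch.int64), new.to(torch.int64))  # same values here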
