@@ -537,7 +537,6 @@ def _matmul_static_shapes_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     return make_precompiler(_matmul_static_shapes_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )
 
-    @unittest.skip("need to debug correctness issue")
     def test_matmul_static_shapes2(self):
         args = (
             torch.randn([128, 127], device=DEVICE, dtype=torch.float32),
@@ -553,6 +552,8 @@ def test_matmul_static_shapes2(self):
         self.assertExpectedInline(
             code,
             """\
+from __future__ import annotations
+
 import torch
 import triton
 import triton.language as tl
@@ -568,17 +569,18 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_
     pid_0 = first_pid_m + tl.program_id(0) % num_pid_in_group % group_size_m
     pid_1 = tl.program_id(0) % num_pid_in_group // group_size_m
     offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     offset_1 = pid_1 * _BLOCK_SIZE_1
-    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
     for offset_2 in range(0, 127, _BLOCK_SIZE_2):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
         mask_2 = indices_2 < 127
+        acc_copy = acc
         load = tl.load(x + (indices_0[:, None] * 127 + indices_2[None, :] * 1), mask_2[None, :], other=0)
         load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), mask_2[:, None], other=0)
         mm = tl.dot(load, load_1, input_precision='tf32')
-        acc = acc + mm
+        acc = acc_copy + mm
     tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None)
 
 def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
@@ -590,10 +592,20 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
     _BLOCK_SIZE_1 = 16
     _BLOCK_SIZE_2 = 16
     _matmul_static_shapes_kernel[triton.cdiv(128, _BLOCK_SIZE_0) * triton.cdiv(128, _BLOCK_SIZE_1),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
-    return out""",
+    return out
+
+def _matmul_static_shapes_make_precompiler(x: torch.Tensor, y: torch.Tensor):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_matmul_static_shapes_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )
 
-    @unittest.skip("need to debug correctness issue")
     def test_matmul_static_shapes3(self):
         args = (
             torch.randn([127, 128], device=DEVICE, dtype=torch.float32),
@@ -609,6 +621,8 @@ def test_matmul_static_shapes3(self):
         self.assertExpectedInline(
             code,
             """\
+from __future__ import annotations
+
 import torch
 import triton
 import triton.language as tl
@@ -624,18 +638,19 @@ def _matmul_static_shapes_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_
     pid_0 = first_pid_m + tl.program_id(0) % num_pid_in_group % group_size_m
     pid_1 = tl.program_id(0) % num_pid_in_group // group_size_m
     offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
     mask_0 = indices_0 < 127
     offset_1 = pid_1 * _BLOCK_SIZE_1
-    indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < 127
     acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
     for offset_2 in range(0, 128, _BLOCK_SIZE_2):
         indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        acc_copy = acc
         load = tl.load(x + (indices_0[:, None] * 128 + indices_2[None, :] * 1), mask_0[:, None], other=0)
         load_1 = tl.load(y + (indices_2[:, None] * 127 + indices_1[None, :] * 1), mask_1[None, :], other=0)
         mm = tl.dot(load, load_1, input_precision='tf32')
-        acc = acc + mm
+        acc = acc_copy + mm
     tl.store(out + (indices_0[:, None] * 127 + indices_1[None, :] * 1), acc, mask_0[:, None] & mask_1[None, :])
 
 def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
@@ -647,7 +662,18 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
     _BLOCK_SIZE_1 = 16
     _BLOCK_SIZE_2 = 16
     _matmul_static_shapes_kernel[triton.cdiv(127, _BLOCK_SIZE_0) * triton.cdiv(127, _BLOCK_SIZE_1),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
-    return out""",
+    return out
+
+def _matmul_static_shapes_make_precompiler(x: torch.Tensor, y: torch.Tensor):
+    m, k = x.size()
+    k2, n = y.size()
+    assert k == k2, f'size mismatch {k} != {k2}'
+    out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_matmul_static_shapes_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
         )
 
     def test_matmul_split_k(self):