Skip to content

Commit 91c07d7

Browse files
authored
Fix failing tests on main (#244)
There was an upstream triton change to min_dot_size.
1 parent b1b474a commit 91c07d7

File tree

4 files changed

+15
-4
lines changed

4 files changed

+15
-4
lines changed

helion/_compat.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,16 @@ def torch_dtype_to_tl(torch_dtype: torch.dtype) -> object:
3232
return getattr(tl, name_str)
3333

3434

35-
@functools.cache
3635
def min_dot_size(
3736
device: torch.device, lhs: torch.dtype, rhs: torch.dtype
37+
) -> tuple[int, int, int]:
38+
# call private func we can patch in testing
39+
return _min_dot_size(device, lhs, rhs)
40+
41+
42+
@functools.cache
43+
def _min_dot_size(
44+
device: torch.device, lhs: torch.dtype, rhs: torch.dtype
3845
) -> tuple[int, int, int]:
3946
if device.type != "cuda":
4047
# TODO(jansel): support non-cuda properly

helion/_compiler/inductor_lowering.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -786,12 +786,12 @@ def apply_dot_requirements(
786786
lshape = lproxy.size()
787787
rshape = rproxy.size()
788788
# use last two dimensions for dot (supports 2D and batched 3D tensors)
789-
n, k = lshape[-2], lshape[-1]
790-
k2, m = rshape[-2], rshape[-1]
789+
m, k = lshape[-2], lshape[-1]
790+
k2, n = rshape[-2], rshape[-1]
791791
assert k == k2, f"Mismatched k dimensions for dot: {k} vs {k2}"
792792
a, b, c = min_dot_size(lproxy.device, lproxy.dtype, rproxy.dtype)
793793
env = CompileEnvironment.current()
794-
for shape, min_size in [(n, a), (k, b), (m, c)]:
794+
for shape, min_size in [(m, a), (n, b), (k, c)]:
795795
block_idx = CompileEnvironment.current().get_block_id(shape)
796796
if block_idx is not None:
797797
env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)

test/test_autotuner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def setUp(self):
3535
random.seed(112)
3636

3737
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
38+
@patch.object(_compat, "_min_dot_size", lambda *args: (16, 16, 16))
3839
@patch.object(loops, "_supports_warp_specialize", lambda: True)
3940
def test_config_fragment0(self):
4041
args = (

test/test_grid.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

33
import unittest
4+
from unittest.mock import patch
45

56
from expecttest import TestCase
67
import torch
78

89
import helion
10+
from helion import _compat
911
from helion._testing import DEVICE
1012
from helion._testing import code_and_output
1113
import helion.language as hl
@@ -30,6 +32,7 @@ def grid_2d_pytorch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
3032

3133

3234
class TestGrid(TestCase):
35+
@patch.object(_compat, "_min_dot_size", lambda *args: (16, 16, 16))
3336
def test_grid_1d(self):
3437
@helion.kernel(static_shapes=True)
3538
def grid_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments (0)