Commit 0ba5311

x[i] returns scalar when i=scalar (#223)
1 parent f363965 commit 0ba5311

8 files changed: +74 -47 lines
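
In short: when a device-side index is a scalar (for example an `hl.grid` offset), `x[i]` now lowers to a scalar value instead of a one-element tensor, so generated Triton code indexes with the raw offset and scalar `if`/`else` control flow works directly. A minimal sketch of the user-facing behavior, adapted from the `test_if_arg_indexed_scalar` test added below (the function name is illustrative; assumes a CUDA device, the usual `import helion.language as hl` alias, and that calling the kernel directly triggers compilation):

import torch
import helion
import helion.language as hl

@helion.kernel
def double_where_nonzero(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.zeros_like(x)
    for idx in hl.grid(x.shape[0]):
        # y[idx] is now a scalar rather than a one-element tensor, so
        # comparing it against 0 yields a scalar and `if`/`else` compiles.
        if y[idx] != 0:
            output[idx] = x[idx] * 2
        else:
            output[idx] = x[idx]
    return output

x = torch.tensor([1.0, 2.0, 3.0, 4.0], device="cuda")
y = torch.tensor([0, 1, 0, 1], device="cuda", dtype=torch.int32)
result = double_where_nonzero(x, y)  # expected: tensor([1., 4., 3., 8.])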

helion/_compiler/device_ir.py

Lines changed: 0 additions & 8 deletions
@@ -37,7 +37,6 @@
 from .ast_extension import LoopType
 from .ast_extension import NodeVisitor
 from .ast_extension import create
-from .ast_extension import expr_from_string
 from .ast_read_writes import ReadWrites
 from .compile_environment import CompileEnvironment
 from .host_function import HostFunction
@@ -239,13 +238,6 @@ def name(self) -> str:
     def codegen(self, state: CodegenState) -> list[object]:
         test = state.ast_arg(0)
 
-        test_proxy = state.proxy_arg(0)
-        if isinstance(test_proxy, torch.Tensor) and test_proxy.numel() == 1:
-            # Triton does not support `if one_elem_tensor:` but supports `if scalar:`,
-            # so we need to use tl.sum to extract the scalar.
-            test_code = ast.unparse(test)
-            test = expr_from_string(f"tl.sum({test_code})")
-
         args = state.ast_args[2]
         assert isinstance(args, list)
         assert all(isinstance(x, ast.AST) for x in args)
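
Why this block can go: with the indexing change below, an `if` test that originates from a scalar-indexed load is already a scalar, so there is nothing to extract with `tl.sum`. Conditions over genuinely multi-element tiles must now be reduced explicitly by the user (see the new `test_if_arg_tensor_sum` further down). The effect on the emitted condition, quoted from the updated expected output in test/test_control_flow.py:

# before this commit (one-element tensor test, wrapped by codegen):
if tl.sum(v_1):
    ...

# after this commit (scalar test, used directly):
if v_1:
    ...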

helion/_compiler/indexing_strategy.py

Lines changed: 2 additions & 2 deletions
@@ -274,9 +274,9 @@ def create(
                 mask_values.setdefault(f"({mask}){expand}")
                 output_idx += 1
             else:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+                # When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated.
                 val = state.device_function.literal_expr(k)
-                index_values.append(f"tl.full([1], {val}, {dtype}){expand}")
+                index_values.append(f"({val})")
         elif isinstance(k, slice) and str(k) == "slice(None, None, None)":
             expand = tile_strategy.expand_str(output_size, output_idx)
             size = fake_value.size(len(index_values))
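
The codegen effect, visible throughout the updated expected outputs below: a scalar index no longer materializes a one-element tensor with a broadcasting `expand`; it participates in the pointer arithmetic directly and its output dimension is eliminated. For example (quoted from test/test_grid.py):

# before: scalar offset wrapped in a one-element tensor
load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)

# after: scalar offset used directly
load = tl.load(x + offset_0 * x_stride_0, None)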

test/test_broadcasting.py

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ def _fn_kernel(a, out0, out1, out2, a_size_0, a_size_1, a_stride_0, a_stride_1,
     v_1 = load_2 + subscript
     tl.store(out1 + (indices_0[:, None] * out1_stride_0 + indices_1[None, :] * out1_stride_1), v_1, mask_0[:, None] & mask_1[None, :])
     load_4 = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-    load_5 = tl.load(a + (indices_0[:, None] * a_stride_0 + tl.full([1], idx1, tl.int32)[None, :] * a_stride_1), mask_0[:, None], other=0)
+    load_5 = tl.load(a + (indices_0[:, None] * a_stride_0 + idx1 * a_stride_1), mask_0[:, None], other=0)
     v_2 = load_4 + load_5
     tl.store(out2 + (indices_0[:, None] * out2_stride_0 + indices_1[None, :] * out2_stride_1), v_2, mask_0[:, None] & mask_1[None, :])
 

test/test_control_flow.py

Lines changed: 38 additions & 15 deletions
@@ -86,18 +86,16 @@ def _fn_make_precompiler(x, v):
     return make_precompiler(_fn_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
         )
 
-    def test_if_arg_one_element_tensor(self):
+    def test_if_arg_indexed_scalar(self):
         @helion.kernel
         def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
             output = torch.zeros_like(x)
 
             for idx in hl.grid(x.shape[0]):
-                # Since `y[idx]` is a one-element tensor, comparing it against 0 will also create a one-element tensor.
+                # Since `y[idx]` is a scalar, comparing it against 0 will also create a scalar.
                 if y[idx] != 0:
                     output[idx] = x[idx] * 2
-                if (
-                    y[idx] == 0
-                ):  # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if.
+                else:
                     output[idx] = x[idx]
 
             return output
@@ -123,20 +121,18 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    load = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
+    load = tl.load(y + offset_0 * y_stride_0, None)
     v_0 = tl.full([], 0, tl.int32)
     v_1 = load != v_0
-    if tl.sum(v_1):
-        load_1 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
+    if v_1:
+        load_1 = tl.load(x + offset_0 * x_stride_0, None)
         v_2 = 2.0
         v_3 = load_1 * v_2
-        tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, v_3, None)
-    load_2 = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None)
-    v_4 = tl.full([], 0, tl.int32)
-    v_5 = load_2 == v_4
-    if tl.sum(v_5):
-        load_3 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
-        tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, load_3, None)
+        tl.store(output + offset_0 * output_stride_0, v_3, None)
+    _not = not v_1
+    if _not:
+        load_2 = tl.load(x + offset_0 * x_stride_0, None)
+        tl.store(output + offset_0 * output_stride_0, load_2, None)
 
 def fn(x: torch.Tensor, y: torch.Tensor):
     output = torch.zeros_like(x)
@@ -149,6 +145,33 @@ def _fn_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     return make_precompiler(_fn_kernel)(x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)""",
         )
 
+    def test_if_arg_tensor_sum(self):
+        @helion.kernel
+        def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.zeros_like(x)
+
+            for tile in hl.tile(x.shape[0]):
+                # Since `y[tile]` is a tensor, comparing it against 0 will also create a tensor.
+                # An `if` condition must take a scalar, so we call .sum() to reduce the tensor to a scalar.
+                if (y[tile] != 0).sum():
+                    output[tile] = x[tile] * 2
+                if (
+                    y[tile] == 0
+                ).sum():  # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if.
+                    output[tile] = x[tile]
+
+            return output
+
+        x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=DEVICE)
+        y = torch.tensor([0, 1, 0, 1], device=DEVICE, dtype=torch.int32)
+        expected = torch.tensor([1.0, 4.0, 3.0, 8.0], device=DEVICE)
+        code, result = code_and_output(
+            fn,
+            (x, y),
+            block_size=1,
+        )
+        torch.testing.assert_close(result, expected)
+
     def test_constant_true(self):
         @helion.kernel(
             config={

test/test_examples.py

Lines changed: 4 additions & 4 deletions
@@ -1692,11 +1692,11 @@ def test_moe_matmul_ogs(self):
 def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A_stride_0, A_stride_1, C_stride_0, C_stride_1, W_stride_0, W_stride_1, W_stride_2, expert_token_counts_stride_0, expert_token_offsets_stride_0, sorted_to_orig_token_idx_stride_0, max_T_per_expert, N, K, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0
-    start = tl.load(expert_token_offsets + tl.full([1], offset_0, tl.int32) * expert_token_offsets_stride_0, None)
-    num_tokens = tl.load(expert_token_counts + tl.full([1], offset_0, tl.int32) * expert_token_counts_stride_0, None)
+    start = tl.load(expert_token_offsets + offset_0 * expert_token_offsets_stride_0, None)
+    num_tokens = tl.load(expert_token_counts + offset_0 * expert_token_counts_stride_0, None)
     v_0 = tl.full([], 0, tl.int32)
     v_1 = num_tokens != v_0
-    if tl.sum(v_1):
+    if v_1:
         num_tokens_copy = num_tokens
         start_copy = start
         num_tokens_copy_0 = num_tokens_copy
@@ -1729,7 +1729,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
     expert_orig_token_indices_copy_0 = expert_orig_token_indices_copy
     acc_copy_0 = acc_copy
     A_frag = tl.load(A + (expert_orig_token_indices_copy_0[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
-    W_frag = tl.load(W + (tl.full([1], offset_0, tl.int32)[:, None] * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
+    W_frag = tl.load(W + (offset_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
     acc = tl.dot(A_frag, W_frag, acc=acc_copy_0, input_precision='tf32')
     existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
     view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])

test/test_grid.py

Lines changed: 16 additions & 16 deletions
@@ -87,11 +87,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co
     indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
     acc_copy = acc
     acc_copy_0 = acc_copy
-    load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
+    load = tl.load(x + (offset_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None)
     load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0)
     acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
-    tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
+    tl.store(out + (offset_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :])
 
 def grid_1d(x: torch.Tensor, y: torch.Tensor):
     b, m, k = x.size()
@@ -225,11 +225,11 @@ def _grid_2d_idx_list_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE
     indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
     acc_copy = acc
     acc_copy_0 = acc_copy
-    load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
+    load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
     load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
     acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
-    tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
+    tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
 
 def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor):
     bi, bj, m, k = x.size()
@@ -363,11 +363,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI
     indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
     acc_copy = acc
     acc_copy_0 = acc_copy
-    load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
+    load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None)
     load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None)
     acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32')
     v_0 = acc.to(tl.float16)
-    tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
+    tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None)
 
 def grid_2d_idx_nested(x: torch.Tensor, y: torch.Tensor):
     bi, bj, m, k = x.size()
@@ -425,10 +425,10 @@ def _grid_begin_end_kernel(x, out, out_stride_0, x_stride_0):
     pid_0 = tl.program_id(0)
     begin_0 = 2
     offset_0 = begin_0 + pid_0
-    load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
+    load = tl.load(x + offset_0 * x_stride_0, None)
     v_0 = 2.0
     v_1 = load * v_0
-    tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
+    tl.store(out + offset_0 * out_stride_0, v_1, None)
 
 def grid_begin_end(x: torch.Tensor):
     n = x.size(0)
@@ -475,10 +475,10 @@ def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor:
 def _grid_begin_end_step_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
-    load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
+    load = tl.load(x + offset_0 * x_stride_0, None)
     v_0 = 2.0
     v_1 = load * v_0
-    tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
+    tl.store(out + offset_0 * out_stride_0, v_1, None)
 
 def grid_begin_end_step(x: torch.Tensor):
     n = x.size(0)
@@ -527,10 +527,10 @@ def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor:
 def _grid_end_step_kwarg_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
-    load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None)
+    load = tl.load(x + offset_0 * x_stride_0, None)
     v_0 = 2.0
     v_1 = load * v_0
-    tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None)
+    tl.store(out + offset_0 * out_stride_0, v_1, None)
 
 def grid_end_step_kwarg(x: torch.Tensor):
     n = x.size(0)
@@ -587,10 +587,10 @@ def _grid_multidim_begin_end_kernel(x, out, out_stride_0, out_stride_1, x_stride
     offset_0 = begin_0 + pid_0
     begin_1 = 1
     offset_1 = begin_1 + pid_1
-    load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
+    load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
     v_0 = 2.0
     v_1 = load * v_0
-    tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
+    tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
 
 def grid_multidim_begin_end(x: torch.Tensor):
     m, n = x.size()
@@ -643,10 +643,10 @@ def _grid_multidim_begin_end_step_kernel(x, out, out_stride_0, out_stride_1, x_s
     pid_1 = tl.program_id(0) // num_blocks_0
     offset_0 = pid_0 * _BLOCK_SIZE_0
     offset_1 = pid_1 * _BLOCK_SIZE_1
-    load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None)
+    load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None)
     v_0 = 2.0
     v_1 = load * v_0
-    tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None)
+    tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None)
 
 def grid_multidim_begin_end_step(x: torch.Tensor):
     m, n = x.size()

test/test_masking.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import unittest
+
 from expecttest import TestCase
 import torch
 
@@ -332,3 +334,7 @@ def _fn_make_precompiler(x):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
         )
+
+
+if __name__ == "__main__":
+    unittest.main()

test/test_register_tunable.py

Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import unittest
+
 from expecttest import TestCase
 import torch
 
@@ -187,7 +189,7 @@ def _fn_kernel(x, partial, partial_stride_0, x_stride_0, m, _BLOCK_SIZE_0: tl.co
     load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
     sum_1 = tl.sum(load, 0)
     floordiv = triton_helpers.div_floor_integer(offset_0, _BLOCK_SIZE_0)
-    tl.store(partial + tl.full([1], floordiv, tl.int32) * partial_stride_0, sum_1, None)
+    tl.store(partial + floordiv * partial_stride_0, sum_1, None)
 
 def fn(x: torch.Tensor):
     m = x.size(0)
@@ -317,3 +319,7 @@ def _matmul_split_k_make_precompiler(x: torch.Tensor, y: torch.Tensor):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_matmul_split_k_kernel)(x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), n, k, m, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=16, num_stages=8)""",
        )
+
+
+if __name__ == "__main__":
+    unittest.main()
