diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 7eff2ad2..eb86de87 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -37,7 +37,6 @@ from .ast_extension import LoopType from .ast_extension import NodeVisitor from .ast_extension import create -from .ast_extension import expr_from_string from .ast_read_writes import ReadWrites from .compile_environment import CompileEnvironment from .host_function import HostFunction @@ -239,13 +238,6 @@ def name(self) -> str: def codegen(self, state: CodegenState) -> list[object]: test = state.ast_arg(0) - test_proxy = state.proxy_arg(0) - if isinstance(test_proxy, torch.Tensor) and test_proxy.numel() == 1: - # Triton does not support `if one_elem_tensor:` but supports `if scalar:`, - # so we need to use tl.sum to extract the scalar. - test_code = ast.unparse(test) - test = expr_from_string(f"tl.sum({test_code})") - args = state.ast_args[2] assert isinstance(args, list) assert all(isinstance(x, ast.AST) for x in args) diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 82ff66ba..042f0cba 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -274,9 +274,9 @@ def create( mask_values.setdefault(f"({mask}){expand}") output_idx += 1 else: - expand = tile_strategy.expand_str(output_size, output_idx) + # When the index is a scalar (no BlockSizeOrigin), the corresponding dim is eliminated. 
val = state.device_function.literal_expr(k) - index_values.append(f"tl.full([1], {val}, {dtype}){expand}") + index_values.append(f"({val})") elif isinstance(k, slice) and str(k) == "slice(None, None, None)": expand = tile_strategy.expand_str(output_size, output_idx) size = fake_value.size(len(index_values)) diff --git a/test/test_broadcasting.py b/test/test_broadcasting.py index 8d1c1dbd..efd65294 100644 --- a/test/test_broadcasting.py +++ b/test/test_broadcasting.py @@ -330,7 +330,7 @@ def _fn_kernel(a, out0, out1, out2, a_size_0, a_size_1, a_stride_0, a_stride_1, v_1 = load_2 + subscript tl.store(out1 + (indices_0[:, None] * out1_stride_0 + indices_1[None, :] * out1_stride_1), v_1, mask_0[:, None] & mask_1[None, :]) load_4 = tl.load(a + (indices_0[:, None] * a_stride_0 + indices_1[None, :] * a_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - load_5 = tl.load(a + (indices_0[:, None] * a_stride_0 + tl.full([1], idx1, tl.int32)[None, :] * a_stride_1), mask_0[:, None], other=0) + load_5 = tl.load(a + (indices_0[:, None] * a_stride_0 + idx1 * a_stride_1), mask_0[:, None], other=0) v_2 = load_4 + load_5 tl.store(out2 + (indices_0[:, None] * out2_stride_0 + indices_1[None, :] * out2_stride_1), v_2, mask_0[:, None] & mask_1[None, :]) diff --git a/test/test_control_flow.py b/test/test_control_flow.py index a19b9c64..0047342f 100644 --- a/test/test_control_flow.py +++ b/test/test_control_flow.py @@ -86,18 +86,16 @@ def _fn_make_precompiler(x, v): return make_precompiler(_fn_kernel)(x, out, x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), v, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""", ) - def test_if_arg_one_element_tensor(self): + def test_if_arg_indexed_scalar(self): @helion.kernel def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: output = torch.zeros_like(x) for idx in hl.grid(x.shape[0]): - # Since `y[idx]` is a one-element tensor, comparing it against 0 will also create a one-element tensor. 
+ # Since `y[idx]` is a scalar, comparing it against 0 will also create a scalar. if y[idx] != 0: output[idx] = x[idx] * 2 - if ( - y[idx] == 0 - ): # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if. + else: output[idx] = x[idx] return output @@ -123,20 +121,18 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def _fn_kernel(x, y, output, output_stride_0, x_stride_0, y_stride_0): pid_0 = tl.program_id(0) offset_0 = pid_0 - load = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None) + load = tl.load(y + offset_0 * y_stride_0, None) v_0 = tl.full([], 0, tl.int32) v_1 = load != v_0 - if tl.sum(v_1): - load_1 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) + if v_1: + load_1 = tl.load(x + offset_0 * x_stride_0, None) v_2 = 2.0 v_3 = load_1 * v_2 - tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, v_3, None) - load_2 = tl.load(y + tl.full([1], offset_0, tl.int32) * y_stride_0, None) - v_4 = tl.full([], 0, tl.int32) - v_5 = load_2 == v_4 - if tl.sum(v_5): - load_3 = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) - tl.store(output + tl.full([1], offset_0, tl.int32) * output_stride_0, load_3, None) + tl.store(output + offset_0 * output_stride_0, v_3, None) + _not = not v_1 + if _not: + load_2 = tl.load(x + offset_0 * x_stride_0, None) + tl.store(output + offset_0 * output_stride_0, load_2, None) def fn(x: torch.Tensor, y: torch.Tensor): output = torch.zeros_like(x) @@ -149,6 +145,33 @@ def _fn_make_precompiler(x: torch.Tensor, y: torch.Tensor): return make_precompiler(_fn_kernel)(x, y, output, output.stride(0), x.stride(0), y.stride(0), num_warps=4, num_stages=3)""", ) + def test_if_arg_tensor_sum(self): + @helion.kernel + def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.zeros_like(x) + + for tile in hl.tile(x.shape[0]): + # Since `y[tile]` is a tensor, comparing it against 0 will also create a tensor.
+ # An `if` condition must take a scalar, so we call .sum() to reduce the tensor to a scalar. + if (y[tile] != 0).sum(): + output[tile] = x[tile] * 2 + if ( + y[tile] == 0 + ).sum(): # TODO(yf225): `else:` raises MLIR error in Triton, so we use a second if. + output[tile] = x[tile] + + return output + + x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=DEVICE) + y = torch.tensor([0, 1, 0, 1], device=DEVICE, dtype=torch.int32) + expected = torch.tensor([1.0, 4.0, 3.0, 8.0], device=DEVICE) + code, result = code_and_output( + fn, + (x, y), + block_size=1, + ) + torch.testing.assert_close(result, expected) + def test_constant_true(self): @helion.kernel( config={ diff --git a/test/test_examples.py b/test/test_examples.py index d9aed2ce..13bf885f 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1692,11 +1692,11 @@ def test_moe_matmul_ogs(self): def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A_stride_0, A_stride_1, C_stride_0, C_stride_1, W_stride_0, W_stride_1, W_stride_2, expert_token_counts_stride_0, expert_token_offsets_stride_0, sorted_to_orig_token_idx_stride_0, max_T_per_expert, N, K, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 - start = tl.load(expert_token_offsets + tl.full([1], offset_0, tl.int32) * expert_token_offsets_stride_0, None) - num_tokens = tl.load(expert_token_counts + tl.full([1], offset_0, tl.int32) * expert_token_counts_stride_0, None) + start = tl.load(expert_token_offsets + offset_0 * expert_token_offsets_stride_0, None) + num_tokens = tl.load(expert_token_counts + offset_0 * expert_token_counts_stride_0, None) v_0 = tl.full([], 0, tl.int32) v_1 = num_tokens != v_0 - if tl.sum(v_1): + if v_1: num_tokens_copy = num_tokens start_copy = start num_tokens_copy_0 = num_tokens_copy @@ -1729,7 +1729,7 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
expert_orig_token_indices_copy_0 = expert_orig_token_indices_copy acc_copy_0 = acc_copy A_frag = tl.load(A + (expert_orig_token_indices_copy_0[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0) - W_frag = tl.load(W + (tl.full([1], offset_0, tl.int32)[:, None] * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0) + W_frag = tl.load(W + (offset_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0) acc = tl.dot(A_frag, W_frag, acc=acc_copy_0, input_precision='tf32') existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0) view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1]) diff --git a/test/test_grid.py b/test/test_grid.py index 9196e5b5..4b77199b 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -87,11 +87,11 @@ def _grid_1d_kernel(x, y, out, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.co indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy - load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None) + load = tl.load(x + (offset_0 * 512 + indices_1[:, None] * 32 + indices_3[None, :] * 1), None) load_1 = tl.load(y + (indices_3[:, None] * 4 + indices_2[None, :] * 1), mask_2[None, :], other=0) acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32') v_0 = acc.to(tl.float16) - tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :]) + tl.store(out + (offset_0 * 64 + indices_1[:, None] * 4 + indices_2[None, :] * 1), v_0, mask_2[None, :]) def grid_1d(x: torch.Tensor, y: torch.Tensor): b, m, k = x.size() @@ -225,11 +225,11 @@ def _grid_2d_idx_list_kernel(x, y, out, 
_BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy - load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None) + load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None) load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None) acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32') v_0 = acc.to(tl.float16) - tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None) + tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None) def grid_2d_idx_list(x: torch.Tensor, y: torch.Tensor): bi, bj, m, k = x.size() @@ -363,11 +363,11 @@ def _grid_2d_idx_nested_kernel(x, y, out, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SI indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32) acc_copy = acc acc_copy_0 = acc_copy - load = tl.load(x + (tl.full([1], offset_0, tl.int32)[:, None] * 8192 + tl.full([1], offset_1, tl.int32)[:, None] * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None) + load = tl.load(x + (offset_0 * 8192 + offset_1 * 2048 + indices_2[:, None] * 32 + indices_4[None, :] * 1), None) load_1 = tl.load(y + (indices_4[:, None] * 16 + indices_3[None, :] * 1), None) acc = tl.dot(load, load_1, acc=acc_copy_0, input_precision='tf32') v_0 = acc.to(tl.float16) - tl.store(out + (tl.full([1], offset_0, tl.int32)[:, None] * 4096 + tl.full([1], offset_1, tl.int32)[:, None] * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None) + tl.store(out + (offset_0 * 4096 + offset_1 * 1024 + indices_2[:, None] * 16 + indices_3[None, :] * 1), v_0, None) def grid_2d_idx_nested(x: torch.Tensor, y: 
torch.Tensor): bi, bj, m, k = x.size() @@ -425,10 +425,10 @@ def _grid_begin_end_kernel(x, out, out_stride_0, x_stride_0): pid_0 = tl.program_id(0) begin_0 = 2 offset_0 = begin_0 + pid_0 - load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) + load = tl.load(x + offset_0 * x_stride_0, None) v_0 = 2.0 v_1 = load * v_0 - tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None) + tl.store(out + offset_0 * out_stride_0, v_1, None) def grid_begin_end(x: torch.Tensor): n = x.size(0) @@ -475,10 +475,10 @@ def grid_begin_end_step_pytorch(x: torch.Tensor) -> torch.Tensor: def _grid_begin_end_step_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 - load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) + load = tl.load(x + offset_0 * x_stride_0, None) v_0 = 2.0 v_1 = load * v_0 - tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None) + tl.store(out + offset_0 * out_stride_0, v_1, None) def grid_begin_end_step(x: torch.Tensor): n = x.size(0) @@ -527,10 +527,10 @@ def grid_end_step_kwarg_pytorch(x: torch.Tensor) -> torch.Tensor: def _grid_end_step_kwarg_kernel(x, out, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 - load = tl.load(x + tl.full([1], offset_0, tl.int32) * x_stride_0, None) + load = tl.load(x + offset_0 * x_stride_0, None) v_0 = 2.0 v_1 = load * v_0 - tl.store(out + tl.full([1], offset_0, tl.int32) * out_stride_0, v_1, None) + tl.store(out + offset_0 * out_stride_0, v_1, None) def grid_end_step_kwarg(x: torch.Tensor): n = x.size(0) @@ -587,10 +587,10 @@ def _grid_multidim_begin_end_kernel(x, out, out_stride_0, out_stride_1, x_stride offset_0 = begin_0 + pid_0 begin_1 = 1 offset_1 = begin_1 + pid_1 - load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None) + load = tl.load(x + 
(offset_0 * x_stride_0 + offset_1 * x_stride_1), None) v_0 = 2.0 v_1 = load * v_0 - tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None) + tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None) def grid_multidim_begin_end(x: torch.Tensor): m, n = x.size() @@ -643,10 +643,10 @@ def _grid_multidim_begin_end_step_kernel(x, out, out_stride_0, out_stride_1, x_s pid_1 = tl.program_id(0) // num_blocks_0 offset_0 = pid_0 * _BLOCK_SIZE_0 offset_1 = pid_1 * _BLOCK_SIZE_1 - load = tl.load(x + (tl.full([1], offset_0, tl.int32) * x_stride_0 + tl.full([1], offset_1, tl.int32) * x_stride_1), None) + load = tl.load(x + (offset_0 * x_stride_0 + offset_1 * x_stride_1), None) v_0 = 2.0 v_1 = load * v_0 - tl.store(out + (tl.full([1], offset_0, tl.int32) * out_stride_0 + tl.full([1], offset_1, tl.int32) * out_stride_1), v_1, None) + tl.store(out + (offset_0 * out_stride_0 + offset_1 * out_stride_1), v_1, None) def grid_multidim_begin_end_step(x: torch.Tensor): m, n = x.size() diff --git a/test/test_masking.py b/test/test_masking.py index 5ec32e56..8e9bf5a1 100644 --- a/test/test_masking.py +++ b/test/test_masking.py @@ -1,5 +1,7 @@ from __future__ import annotations +import unittest + from expecttest import TestCase import torch @@ -332,3 +334,7 @@ def _fn_make_precompiler(x): from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_fn_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)""", ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_register_tunable.py b/test/test_register_tunable.py index e8a20b56..55d44bdf 100644 --- a/test/test_register_tunable.py +++ b/test/test_register_tunable.py @@ -1,5 +1,7 @@ from __future__ import annotations +import unittest + from expecttest import TestCase import torch @@ -187,7 +189,7 @@ def 
_fn_kernel(x, partial, partial_stride_0, x_stride_0, m, _BLOCK_SIZE_0: tl.co load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0) sum_1 = tl.sum(load, 0) floordiv = triton_helpers.div_floor_integer(offset_0, _BLOCK_SIZE_0) - tl.store(partial + tl.full([1], floordiv, tl.int32) * partial_stride_0, sum_1, None) + tl.store(partial + floordiv * partial_stride_0, sum_1, None) def fn(x: torch.Tensor): m = x.size(0) @@ -317,3 +319,7 @@ def _matmul_split_k_make_precompiler(x: torch.Tensor, y: torch.Tensor): from helion.runtime.precompile_shim import make_precompiler return make_precompiler(_matmul_split_k_kernel)(x, y, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), y.stride(0), y.stride(1), n, k, m, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_0, _BLOCK_SIZE_3, num_warps=16, num_stages=8)""", ) + + +if __name__ == "__main__": + unittest.main()