From 3ed37595c3852d6b471542bc5a6770ab6f46b2b6 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 24 Jun 2025 14:15:42 -0700 Subject: [PATCH] Add autotuning for range() unrolling stack-info: PR: https://github.com/pytorch-labs/helion/pull/219, branch: oulgen/stack/14 --- helion/_compiler/device_ir.py | 2 + helion/_compiler/generate_ast.py | 15 +++- helion/_compiler/host_function.py | 2 +- helion/_compiler/static_loop_unroller.py | 53 ++++++++++--- helion/autotuner/config_spec.py | 7 ++ helion/runtime/config.py | 7 ++ test/test_autotuner.py | 48 +++++++++--- test/test_loops.py | 97 ++++++++++++++++++++++++ 8 files changed, 208 insertions(+), 23 deletions(-) diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index b59eec64..26b722d9 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -529,6 +529,8 @@ def visit_For(self, node: ast.For) -> None: ) outputs: LiftTensorArgs | None = None begin, end = self._extract_tile_begin_end(node) + if (begin is None or isinstance(begin, int)) and isinstance(end, int): + CompileEnvironment.current().config_spec.allow_unroll_loops = True if isinstance(inner_type, SequenceType): iter_vars = inner_type.unpack() if begin is None: diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 7fd45f82..e132d46a 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -56,12 +56,16 @@ def mask_var(self, block_idx: int) -> str | None: return loops[-1].strategy.mask_var(block_idx) return None - def add_statement(self, stmt: ast.AST | str | None) -> None: + def add_statement(self, stmt: ast.AST | list[ast.AST] | str | None) -> None: if stmt is None: return if isinstance(stmt, str): stmt = statement_from_string(stmt) - self.statements_stack[-1].append(stmt) + if isinstance(stmt, list): + for s in stmt: + self.statements_stack[-1].append(s) + else: + self.statements_stack[-1].append(stmt) def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str: return self.device_function.unique_name(prefix, dce=dce) @@ -116,7 +120,12 @@ def add_device_loop(self, device_loop: DeviceLoopState) -> Iterator[None]: for idx in device_loop.block_ids: self.active_device_loops[idx].pop() self.statements_stack[-1].extend(device_loop.outer_prefix) - self.add_statement(device_loop.for_node) + stmt = device_loop.for_node + if self.device_function.config.unroll_loops: + from .static_loop_unroller import unroll_loop + + stmt = unroll_loop(node=device_loop.for_node, allow_range=True) + self.add_statement(stmt) self.statements_stack[-1].extend(device_loop.outer_suffix) def set_active_loops(self, device_grid: DeviceLoopOrGridState) -> None: diff --git a/helion/_compiler/host_function.py b/helion/_compiler/host_function.py index 6c364a58..11a897c1 100644 --- a/helion/_compiler/host_function.py +++ b/helion/_compiler/host_function.py @@ -103,7 +103,7 @@ def __init__( from .static_loop_unroller import unroll_static_loops from .type_propagation import propagate_types - unroll_static_loops(self) + unroll_static_loops(func=self, allow_range=False) propagate_types(self, fake_args) env.finalize_config_spec() self.device_ir = lower_to_device_ir(self) diff --git a/helion/_compiler/static_loop_unroller.py b/helion/_compiler/static_loop_unroller.py index be3fef0a..1df1a3a3 100644 --- a/helion/_compiler/static_loop_unroller.py +++ b/helion/_compiler/static_loop_unroller.py @@ -23,6 +23,9 @@ class StaticLoopUnroller(ast.NodeTransformer): TODO(oulgen): This pass is primitive, does not handle 
    for.orelse, break, continue, etc.
     """
 
+    def __init__(self, allow_range: bool) -> None:
+        self.allow_range = allow_range
+
     def visit_For(self, node: ast.For) -> ast.AST | list[ast.AST]:
         # Generic visit to handle nested loops
         node = self.generic_visit(node)  # pyre-ignore[9]
@@ -45,6 +48,35 @@ def _extract_static_values(self, iter_node: ast.expr) -> list[ast.expr] | None:
         """
         if isinstance(iter_node, (ast.List, ast.Tuple)):
             return iter_node.elts
+        if (
+            self.allow_range
+            and isinstance(iter_node, ast.Call)
+            and isinstance(iter_node.func, ast.Name)
+            and iter_node.func.id == "range"
+        ):
+            range_values = self._extract_range_values(iter_node)
+            if range_values is not None:
+                return [create(ast.Constant, value=val) for val in range_values]
+
         return None
+
+    def _extract_range_values(self, range_call: ast.Call) -> list[int] | None:
+        """
+        Extract values from a range() call if all arguments are constants.
+        """
+        args = range_call.args
+
+        for arg in args:
+            if not isinstance(arg, ast.Constant) or not isinstance(arg.value, int):
+                return None
+
+        if len(args) == 1:
+            return list(range(args[0].value))  # pyre-ignore[16]
+        if len(args) == 2:
+            return list(range(args[0].value, args[1].value))
+        if len(args) == 3:
+            return list(range(args[0].value, args[1].value, args[2].value))
+        return None
 
 
 def _unroll_loop(
@@ -68,14 +100,17 @@ def _unroll_loop(
     return unrolled_statements
 
-def unroll_static_loops(func: HostFunction) -> None:
-    new_body = []
+def unroll_loop(*, node: ast.AST, allow_range: bool) -> ast.AST | list[ast.AST]:
+    try:
+        return StaticLoopUnroller(allow_range).visit(node)
+    except CannotUnrollLoop:
+        return node
+
+
+def unroll_static_loops(*, func: HostFunction, allow_range: bool) -> None:
+    new_body: list[ast.stmt] = []
     for stmt in func.body:
-        try:
-            unrolled_stmts = StaticLoopUnroller().visit(stmt)
-        except CannotUnrollLoop:
-            new_body.append(stmt)
-        else:
-            assert isinstance(unrolled_stmts, ast.stmt)
-            new_body.append(unrolled_stmts)
+        maybe_unrolled = unroll_loop(node=stmt, allow_range=allow_range)
+        assert isinstance(maybe_unrolled, ast.stmt)
+        new_body.append(maybe_unrolled)
     func.body = new_body
diff --git a/helion/autotuner/config_spec.py b/helion/autotuner/config_spec.py
index 13d1e5ae..ae44f970 100644
--- a/helion/autotuner/config_spec.py
+++ b/helion/autotuner/config_spec.py
@@ -39,6 +39,7 @@
         "num_warps",
         "num_stages",
         "use_yz_grid",
+        "unroll_loops",
         "indexing",
     ]
 )
@@ -65,6 +66,7 @@ class ConfigSpec:
         default_factory=dict
     )
     allow_use_yz_grid: bool | None = None
+    allow_unroll_loops: bool | None = None
 
     def _remove_duplicates(self) -> None:
         self.loop_orders._remove_duplicates()
@@ -111,6 +113,8 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
 
         if self.allow_use_yz_grid:
             config.setdefault("use_yz_grid", False)
+        if self.allow_unroll_loops:
+            config.setdefault("unroll_loops", False)
 
         config.setdefault("indexing", "pointer")
 
@@ -151,6 +155,9 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
                 not config["flatten_loops"] or not config["flatten_loops"][0]
             ):
                 config["use_yz_grid"] = use_yz_grid
+        if self.allow_unroll_loops:
+            config["unroll_loops"] = fn(BooleanFragment())
+
         for name in ("loop_orders", "flatten_loops", "reduction_loops", "l2_groupings"):
             if not config[name]:
                 config.pop(name)
diff --git a/helion/runtime/config.py b/helion/runtime/config.py
index b7968547..cb6db950 100644
--- a/helion/runtime/config.py
+++ b/helion/runtime/config.py
@@ -28,6 +28,7 @@ def __init__(
         num_warps: int | None = None,
         num_stages: int | None = None,
         use_yz_grid: bool | None = None,
+        unroll_loops: bool | None = None,
         indexing: IndexingLiteral | None = None,
         # For user-defined properties
         **kwargs: object,
@@ -43,6 +44,7 @@ def __init__(
             num_warps: Number of warps per block.
             num_stages: Number of stages for software pipelining.
             use_yz_grid: Whether to use yz grid dimensions.
+            unroll_loops: Whether to unroll device loops with constant integer bounds.
             indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
             **kwargs: Additional user-defined configuration parameters.
         """
@@ -57,6 +59,7 @@ def __init__(
             "num_stages": num_stages,
             "indexing": indexing,
             "use_yz_grid": use_yz_grid,
+            "unroll_loops": unroll_loops,
         }
         for key, value in core_props.items():
             if value is not None:
@@ -138,6 +141,10 @@ def l2_groupings(self) -> list[int]:
     def use_yz_grid(self) -> bool:
         return cast("bool", self.config.get("use_yz_grid", False))
 
+    @property
+    def unroll_loops(self) -> bool:
+        return cast("bool", self.config.get("unroll_loops", False))
+
     @property
     def indexing(self) -> IndexingLiteral:
         return self.config.get("indexing", "pointer")  # type: ignore
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index c7b23827..85ed99ae 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -44,16 +44,16 @@ def test_config_fragment0(self):
         self.assertExpectedInline(
             "\n".join(map(repr, configs)),
             """\
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=3, indexing='pointer')
-helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[8], num_warps=32, num_stages=3, indexing='block_ptr')
-helion.Config(block_sizes=[32, 16, 16], loop_orders=[[1, 0]], l2_groupings=[32], num_warps=8, num_stages=8, indexing='block_ptr')
-helion.Config(block_sizes=[16, 16, 32], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=7, indexing='tensor_descriptor')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=8, num_stages=2, indexing='tensor_descriptor')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], num_warps=4, num_stages=7, indexing='tensor_descriptor')
-helion.Config(block_sizes=[32, 128, 64], loop_orders=[[0, 1]], l2_groupings=[2], num_warps=16, num_stages=5, indexing='pointer')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], num_warps=16, num_stages=3, indexing='tensor_descriptor')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=2, indexing='block_ptr')
-helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=1, num_stages=1, indexing='tensor_descriptor')""",
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], num_warps=4, num_stages=3, indexing='pointer', unroll_loops=False)
+helion.Config(block_sizes=[32, 128, 64], loop_orders=[[1, 0]], l2_groupings=[8], num_warps=32, num_stages=3, indexing='block_ptr', unroll_loops=False)
+helion.Config(block_sizes=[128, 16, 128], loop_orders=[[0, 1]], l2_groupings=[8], num_warps=4, num_stages=6, indexing='pointer', unroll_loops=False)
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=7, indexing='tensor_descriptor', unroll_loops=True)
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[8], num_warps=32, num_stages=2, indexing='tensor_descriptor', unroll_loops=False)
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[64], num_warps=4, num_stages=7, indexing='tensor_descriptor', unroll_loops=False)
+helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=4, num_stages=4, indexing='tensor_descriptor', unroll_loops=False)
+helion.Config(block_sizes=[64, 16, 128], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=32, num_stages=2, indexing='tensor_descriptor', unroll_loops=False)
+helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], num_warps=16, num_stages=3, indexing='block_ptr', unroll_loops=False)
+helion.Config(block_sizes=[16, 32, 32], loop_orders=[[0, 1]], l2_groupings=[4], num_warps=4, num_stages=7, indexing='block_ptr', unroll_loops=True)""",
         )
 
     @patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
@@ -187,6 +187,34 @@ def test_differential_evolution_search(self):
         fn = bound_kernel.compile_config(best)
         torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
 
+    def test_loop_unroll(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+                for i in range(1, 4):
+                    out[tile] += i
+            return out
+
+        args = (torch.randn(4, device=DEVICE),)
+        spec = fn.bind(args).config_spec
+        configs = ConfigGeneration(spec).random_population(10)
+        self.assertExpectedInline(
+            "\n".join(map(repr, configs)),
+            """\
+helion.Config(block_sizes=[4], num_warps=4, num_stages=3, indexing='pointer', unroll_loops=False)
+helion.Config(block_sizes=[2], num_warps=32, num_stages=5, indexing='block_ptr', unroll_loops=True)
+helion.Config(block_sizes=[1], num_warps=4, num_stages=4, indexing='block_ptr', unroll_loops=False)
+helion.Config(block_sizes=[4], num_warps=2, num_stages=8, indexing='block_ptr', unroll_loops=True)
+helion.Config(block_sizes=[1], num_warps=2, num_stages=3, indexing='block_ptr', unroll_loops=True)
+helion.Config(block_sizes=[4], num_warps=8, num_stages=5, indexing='pointer', unroll_loops=False)
+helion.Config(block_sizes=[4], num_warps=2, num_stages=5, indexing='block_ptr', unroll_loops=False)
+helion.Config(block_sizes=[1], num_warps=1, num_stages=4, indexing='block_ptr', unroll_loops=True)
+helion.Config(block_sizes=[2], num_warps=32, num_stages=7, indexing='pointer', unroll_loops=True)
+helion.Config(block_sizes=[2], num_warps=2, num_stages=2, indexing='pointer', unroll_loops=True)""",
+        )
+
     def test_use_default_config(self):
         @helion.kernel(use_default_config=True)
         def add(a, b):
diff --git a/test/test_loops.py b/test/test_loops.py
index 0f01bd5b..c04a0f0d 100644
--- a/test/test_loops.py
+++ b/test/test_loops.py
@@ -1499,6 +1499,103 @@ def _fn_make_precompiler(x: torch.Tensor):
     return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
         )
 
+    def test_loop_unroll3(self):
+        @helion.kernel()
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x)
+            for tile in hl.tile(x.size()):
+                out[tile] = x[tile]
+                for i in range(1, 4):
+                    out[tile] += i
+            return out
+
+        x = torch.randn(4, device=DEVICE)
+        code, output = code_and_output(fn, (x,), block_sizes=[4], unroll_loops=True)
+        torch.testing.assert_close(output, x + 6)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    tl.store(out + indices_0 * out_stride_0, load, mask_0)
+    offset_1 = 1
+    load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_0 = offset_1.to(tl.float32)
+    v_1 = load_1 + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    offset_1 = 2
+    load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_0 = offset_1.to(tl.float32)
+    v_1 = load_1 + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+    offset_1 = 3
+    load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+    v_0 = offset_1.to(tl.float32)
+    v_1 = load_1 + v_0
+    tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+
+def fn(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
+        code, output = code_and_output(fn, (x,), block_sizes=[4], unroll_loops=False)
+        torch.testing.assert_close(output, x + 6)
+        self.assertExpectedInline(
+            code,
+            """\
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    tl.store(out + indices_0 * out_stride_0, load, mask_0)
+    for offset_1 in range(1, 4, 1):
+        load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
+        v_0 = offset_1.to(tl.float32)
+        v_1 = load_1 + v_0
+        tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
+
+def fn(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _fn_make_precompiler(x: torch.Tensor):
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
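---

A minimal usage sketch of the new unroll_loops knob, mirroring test_loop_unroll3
above. It assumes helion's public API as exercised by these tests, in particular
that helion.kernel accepts a pinned config=helion.Config(...) in place of
autotuning; the kernel name add_constants and the "cuda" device string are
illustrative only.

import torch

import helion
import helion.language as hl

# With unroll_loops=True the constant-bound range() loop below is emitted
# three times in the generated Triton kernel; with unroll_loops=False the
# same kernel keeps a device-side `for offset in range(1, 4, 1)` loop, as
# the two expected outputs in test_loop_unroll3 show.
@helion.kernel(config=helion.Config(block_sizes=[4], unroll_loops=True))
def add_constants(x: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    for tile in hl.tile(x.size()):
        out[tile] = x[tile]
        for i in range(1, 4):  # constant bounds -> eligible for unrolling
            out[tile] += i
    return out

x = torch.randn(4, device="cuda")
torch.testing.assert_close(add_constants(x), x + 6)  # 1 + 2 + 3 == 6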