Skip to content

Commit befee2d

Browse files
committed
Add autotuning for range() unrolling
1 parent c86b278 commit befee2d

File tree

8 files changed

+198
-13
lines changed

8 files changed

+198
-13
lines changed

helion/_compiler/device_ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ def visit_For(self, node: ast.For) -> None:
529529
)
530530
outputs: LiftTensorArgs | None = None
531531
begin, end = self._extract_tile_begin_end(node)
532+
if (begin is None or isinstance(begin, int)) and isinstance(end, int):
533+
CompileEnvironment.current().config_spec.allow_unroll_loops = True
532534
if isinstance(inner_type, SequenceType):
533535
iter_vars = inner_type.unpack()
534536
if begin is None:

helion/_compiler/generate_ast.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,16 @@ def mask_var(self, block_idx: int) -> str | None:
5656
return loops[-1].strategy.mask_var(block_idx)
5757
return None
5858

59-
def add_statement(self, stmt: ast.AST | str | None) -> None:
59+
def add_statement(self, stmt: ast.AST | list[ast.AST] | str | None) -> None:
6060
if stmt is None:
6161
return
6262
if isinstance(stmt, str):
6363
stmt = statement_from_string(stmt)
64-
self.statements_stack[-1].append(stmt)
64+
if isinstance(stmt, list):
65+
for s in stmt:
66+
self.statements_stack[-1].append(s)
67+
else:
68+
self.statements_stack[-1].append(stmt)
6569

6670
def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
6771
return self.device_function.unique_name(prefix, dce=dce)
@@ -116,7 +120,12 @@ def add_device_loop(self, device_loop: DeviceLoopState) -> Iterator[None]:
116120
for idx in device_loop.block_ids:
117121
self.active_device_loops[idx].pop()
118122
self.statements_stack[-1].extend(device_loop.outer_prefix)
119-
self.add_statement(device_loop.for_node)
123+
stmt = device_loop.for_node
124+
if self.device_function.config.unroll_loops:
125+
from .static_loop_unroller import unroll_loop
126+
127+
stmt = unroll_loop(node=device_loop.for_node, allow_range=True)
128+
self.add_statement(stmt)
120129
self.statements_stack[-1].extend(device_loop.outer_suffix)
121130

122131
def set_active_loops(self, device_grid: DeviceLoopOrGridState) -> None:

helion/_compiler/host_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(
103103
from .static_loop_unroller import unroll_static_loops
104104
from .type_propagation import propagate_types
105105

106-
unroll_static_loops(self)
106+
unroll_static_loops(func=self, allow_range=False)
107107
propagate_types(self, fake_args)
108108
env.finalize_config_spec()
109109
self.device_ir = lower_to_device_ir(self)

helion/_compiler/static_loop_unroller.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ class StaticLoopUnroller(ast.NodeTransformer):
2323
TODO(oulgen): This pass is primitive, does not handle for.orelse, break, continue etc
2424
"""
2525

26+
def __init__(self, allow_range: bool) -> None:
27+
self.allow_range = allow_range
28+
2629
def visit_For(self, node: ast.For) -> ast.AST | list[ast.AST]:
2730
# Generic visit to handle nested loops
2831
node = self.generic_visit(node) # pyre-ignore[9]
@@ -45,6 +48,35 @@ def _extract_static_values(self, iter_node: ast.expr) -> list[ast.expr] | None:
4548
"""
4649
if isinstance(iter_node, (ast.List, ast.Tuple)):
4750
return iter_node.elts
51+
if (
52+
self.allow_range
53+
and isinstance(iter_node, ast.Call)
54+
and isinstance(iter_node.func, ast.Name)
55+
and iter_node.func.id == "range"
56+
):
57+
range_values = self._extract_range_values(iter_node)
58+
if range_values is not None:
59+
return [create(ast.Constant, value=val) for val in range_values]
60+
61+
return None
62+
63+
def _extract_range_values(self, range_call: ast.Call) -> list[int] | None:
64+
"""
65+
Extract values from a range() call if all arguments are constants.
66+
"""
67+
args = range_call.args
68+
69+
for arg in args:
70+
if not isinstance(arg, ast.Constant) or not isinstance(arg.value, int):
71+
return None
72+
73+
if len(args) == 1:
74+
return list(range(args[0].value)) # pyre-ignore[16]
75+
if len(args) == 2:
76+
return list(range(args[0].value, args[1].value))
77+
if len(args) == 3:
78+
return list(range(args[0].value, args[1].value, args[2].value))
79+
4880
return None
4981

5082
def _unroll_loop(
@@ -68,14 +100,17 @@ def _unroll_loop(
68100
return unrolled_statements
69101

70102

71-
def unroll_loop(*, node: ast.AST, allow_range: bool) -> ast.AST | list[ast.AST]:
    """
    Run ``StaticLoopUnroller`` over ``node``.

    Args:
        node: The AST node to transform.
        allow_range: Forwarded to ``StaticLoopUnroller``; when true, loops over
            ``range(...)`` with constant arguments are also unrolled.

    Returns:
        Whatever the unroller's ``visit`` produces (a node, or a list of nodes
        when a loop was expanded), or ``node`` itself if the unroller raised
        ``CannotUnrollLoop``.
    """
    try:
        return StaticLoopUnroller(allow_range).visit(node)
    except CannotUnrollLoop:
        return node


def unroll_static_loops(*, func: HostFunction, allow_range: bool) -> None:
    """
    Replace ``func.body`` with a copy in which each statement has been passed
    through ``unroll_loop``.

    Args:
        func: The host function whose ``body`` is rewritten in place.
        allow_range: Forwarded to ``unroll_loop``.
    """
    new_body: list[ast.stmt] = []
    for stmt in func.body:
        maybe_unrolled = unroll_loop(node=stmt, allow_range=allow_range)
        # unroll_loop is typed to return either one node or a list of nodes;
        # splice a list into the body instead of failing the assertion.
        # NOTE(review): assumes a top-level loop can be expanded into several
        # statements -- confirm against StaticLoopUnroller.visit_For.
        if isinstance(maybe_unrolled, list):
            replacements = maybe_unrolled
        else:
            replacements = [maybe_unrolled]
        for replacement in replacements:
            assert isinstance(replacement, ast.stmt)
            new_body.append(replacement)
    func.body = new_body

helion/autotuner/config_spec.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"num_warps",
4040
"num_stages",
4141
"use_yz_grid",
42+
"unroll_loops",
4243
"indexing",
4344
]
4445
)
@@ -65,6 +66,7 @@ class ConfigSpec:
6566
default_factory=dict
6667
)
6768
allow_use_yz_grid: bool | None = None
69+
allow_unroll_loops: bool | None = None
6870

6971
def _remove_duplicates(self) -> None:
7072
self.loop_orders._remove_duplicates()
@@ -111,6 +113,8 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
111113

112114
if self.allow_use_yz_grid:
113115
config.setdefault("use_yz_grid", False)
116+
if self.allow_unroll_loops:
117+
config.setdefault("unroll_loops", False)
114118

115119
config.setdefault("indexing", "pointer")
116120

@@ -151,6 +155,9 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
151155
not config["flatten_loops"] or not config["flatten_loops"][0]
152156
):
153157
config["use_yz_grid"] = use_yz_grid
158+
if self.allow_unroll_loops:
159+
config["unroll_loops"] = fn(BooleanFragment())
160+
154161
for name in ("loop_orders", "flatten_loops", "reduction_loops", "l2_groupings"):
155162
if not config[name]:
156163
config.pop(name)

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
num_warps: int | None = None,
2929
num_stages: int | None = None,
3030
use_yz_grid: bool | None = None,
31+
unroll_loops: bool | None = None,
3132
indexing: IndexingLiteral | None = None,
3233
# For user-defined properties
3334
**kwargs: object,
@@ -43,6 +44,7 @@ def __init__(
4344
num_warps: Number of warps per block.
4445
num_stages: Number of stages for software pipelining.
4546
use_yz_grid: Whether to use yz grid dimensions.
47+
unroll_loops: Whether to unroll loops.
4648
indexing: Indexing strategy ("pointer", "tensor_descriptor", "block_ptr").
4749
**kwargs: Additional user-defined configuration parameters.
4850
"""
@@ -57,6 +59,7 @@ def __init__(
5759
"num_stages": num_stages,
5860
"indexing": indexing,
5961
"use_yz_grid": use_yz_grid,
62+
"unroll_loops": unroll_loops,
6063
}
6164
for key, value in core_props.items():
6265
if value is not None:
@@ -138,6 +141,10 @@ def l2_groupings(self) -> list[int]:
138141
def use_yz_grid(self) -> bool:
139142
return cast("bool", self.config.get("use_yz_grid", False))
140143

144+
@property
145+
def unroll_loops(self) -> bool:
146+
return cast("bool", self.config.get("unroll_loops", False))
147+
141148
@property
142149
def indexing(self) -> IndexingLiteral:
143150
return self.config.get("indexing", "pointer") # type: ignore

test/test_autotuner.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,34 @@ def test_differential_evolution_search(self):
187187
fn = bound_kernel.compile_config(best)
188188
torch.testing.assert_close(fn(*args), args[0] @ args[1], rtol=1e-2, atol=1e-1)
189189

190+
def test_loop_unroll(self):
191+
@helion.kernel()
192+
def fn(x: torch.Tensor) -> torch.Tensor:
193+
out = torch.zeros_like(x)
194+
for tile in hl.tile(x.size()):
195+
out[tile] = x[tile]
196+
for i in range(1, 4):
197+
out[tile] += i
198+
return out
199+
200+
args = (torch.randn(4, device=DEVICE),)
201+
spec = fn.bind(args).config_spec
202+
configs = ConfigGeneration(spec).random_population(10)
203+
self.assertExpectedInline(
204+
"\n".join(map(repr, configs)),
205+
"""\
206+
helion.Config(block_sizes=[4], num_warps=4, num_stages=3, indexing='pointer', unroll_loops=False)
207+
helion.Config(block_sizes=[2], num_warps=32, num_stages=5, indexing='block_ptr', unroll_loops=True)
208+
helion.Config(block_sizes=[1], num_warps=4, num_stages=4, indexing='block_ptr', unroll_loops=False)
209+
helion.Config(block_sizes=[4], num_warps=2, num_stages=8, indexing='block_ptr', unroll_loops=True)
210+
helion.Config(block_sizes=[1], num_warps=2, num_stages=3, indexing='block_ptr', unroll_loops=True)
211+
helion.Config(block_sizes=[4], num_warps=8, num_stages=5, indexing='pointer', unroll_loops=False)
212+
helion.Config(block_sizes=[4], num_warps=2, num_stages=5, indexing='block_ptr', unroll_loops=False)
213+
helion.Config(block_sizes=[1], num_warps=1, num_stages=4, indexing='block_ptr', unroll_loops=True)
214+
helion.Config(block_sizes=[2], num_warps=32, num_stages=7, indexing='pointer', unroll_loops=True)
215+
helion.Config(block_sizes=[2], num_warps=2, num_stages=2, indexing='pointer', unroll_loops=True)""",
216+
)
217+
190218
def test_use_default_config(self):
191219
@helion.kernel(use_default_config=True)
192220
def add(a, b):

test/test_loops.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,6 +1499,103 @@ def _fn_make_precompiler(x: torch.Tensor):
14991499
return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
15001500
)
15011501

1502+
def test_loop_unroll3(self):
1503+
@helion.kernel()
1504+
def fn(x: torch.Tensor) -> torch.Tensor:
1505+
out = torch.zeros_like(x)
1506+
for tile in hl.tile(x.size()):
1507+
out[tile] = x[tile]
1508+
for i in range(1, 4):
1509+
out[tile] += i
1510+
return out
1511+
1512+
x = torch.randn(4, device=DEVICE)
1513+
code, output = code_and_output(fn, (x,), block_sizes=[4], unroll_loops=True)
1514+
torch.testing.assert_close(output, x + 6)
1515+
self.assertExpectedInline(
1516+
code,
1517+
"""\
1518+
from __future__ import annotations
1519+
1520+
import torch
1521+
import triton
1522+
import triton.language as tl
1523+
1524+
@triton.jit
1525+
def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
1526+
pid_0 = tl.program_id(0)
1527+
offset_0 = pid_0 * _BLOCK_SIZE_0
1528+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1529+
mask_0 = indices_0 < x_size_0
1530+
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
1531+
tl.store(out + indices_0 * out_stride_0, load, mask_0)
1532+
offset_1 = 1
1533+
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
1534+
v_0 = offset_1.to(tl.float32)
1535+
v_1 = load_1 + v_0
1536+
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
1537+
offset_1 = 2
1538+
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
1539+
v_0 = offset_1.to(tl.float32)
1540+
v_1 = load_1 + v_0
1541+
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
1542+
offset_1 = 3
1543+
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
1544+
v_0 = offset_1.to(tl.float32)
1545+
v_1 = load_1 + v_0
1546+
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
1547+
1548+
def fn(x: torch.Tensor):
1549+
out = torch.zeros_like(x)
1550+
_BLOCK_SIZE_0 = 4
1551+
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
1552+
return out
1553+
1554+
def _fn_make_precompiler(x: torch.Tensor):
1555+
out = torch.zeros_like(x)
1556+
_BLOCK_SIZE_0 = 4
1557+
from helion.runtime.precompile_shim import make_precompiler
1558+
return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
1559+
)
1560+
1561+
code, output = code_and_output(fn, (x,), block_sizes=[4], unroll_loops=False)
1562+
torch.testing.assert_close(output, x + 6)
1563+
self.assertExpectedInline(
1564+
code,
1565+
"""\
1566+
from __future__ import annotations
1567+
1568+
import torch
1569+
import triton
1570+
import triton.language as tl
1571+
1572+
@triton.jit
1573+
def _fn_kernel(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
1574+
pid_0 = tl.program_id(0)
1575+
offset_0 = pid_0 * _BLOCK_SIZE_0
1576+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1577+
mask_0 = indices_0 < x_size_0
1578+
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
1579+
tl.store(out + indices_0 * out_stride_0, load, mask_0)
1580+
for offset_1 in range(1, 4, 1):
1581+
load_1 = tl.load(out + indices_0 * out_stride_0, mask_0, other=0)
1582+
v_0 = offset_1.to(tl.float32)
1583+
v_1 = load_1 + v_0
1584+
tl.store(out + indices_0 * out_stride_0, v_1, mask_0)
1585+
1586+
def fn(x: torch.Tensor):
1587+
out = torch.zeros_like(x)
1588+
_BLOCK_SIZE_0 = 4
1589+
_fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
1590+
return out
1591+
1592+
def _fn_make_precompiler(x: torch.Tensor):
1593+
out = torch.zeros_like(x)
1594+
_BLOCK_SIZE_0 = 4
1595+
from helion.runtime.precompile_shim import make_precompiler
1596+
return make_precompiler(_fn_kernel)(x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
1597+
)
1598+
15021599

15031600
if __name__ == "__main__":
15041601
unittest.main()

0 commit comments

Comments
 (0)