
Commit ce3b6c7

Move register_block_size/register_reduction_dim to tunable_ops.py (#161)
1 parent 2561136 commit ce3b6c7

3 files changed: +142 -136 lines changed

helion/language/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,6 @@
 from .device_print import device_print as device_print
 from .loops import Tile as Tile
 from .loops import grid as grid
-from .loops import register_block_size as register_block_size
-from .loops import register_reduction_dim as register_reduction_dim
 from .loops import tile as tile
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
@@ -17,5 +15,7 @@
 from .tiles import tile_block_size as tile_block_size
 from .tiles import tile_end as tile_end
 from .tiles import tile_index as tile_index
+from .tunable_ops import register_block_size as register_block_size
+from .tunable_ops import register_reduction_dim as register_reduction_dim
 from .tunable_ops import register_tunable as register_tunable
 from .view_ops import subscript as subscript
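
Note that both symbols are still re-exported from helion.language; only the defining module changes, so downstream imports are unaffected. A minimal sanity check, assuming helion is installed in your environment (illustrative, not part of this commit):

import helion.language as hl

# The re-exports above keep the public import paths stable even though
# the definitions now live in tunable_ops.py instead of loops.py.
assert callable(hl.register_block_size)
assert callable(hl.register_reduction_dim)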

helion/language/loops.py

Lines changed: 1 addition & 132 deletions
@@ -8,15 +8,12 @@
 from typing import overload
 
 import torch
-from torch._inductor.codegen.simd import constant_repr
-from torch._inductor.runtime.runtime_utils import next_power_of_2
 from torch._inductor.runtime.triton_heuristics import get_max_y_grid
 
 from .. import exc
 from .._compiler.ast_extension import ExtendedAST
 from .._compiler.ast_extension import LoopType
 from .._compiler.ast_extension import expr_from_string
-from .._compiler.compile_environment import AutoSize
 from .._compiler.compile_environment import CompileEnvironment
 from .._compiler.tile_index_proxy import TileIndexProxy
 from .._compiler.type_propagation import GridIndexType
@@ -26,7 +23,6 @@
 from .._compiler.type_propagation import TileIndexType
 from .._compiler.type_propagation import TypeInfo
 from .._compiler.type_propagation import UnknownType
-from ..autotuner.config_fragment import assert_integer_power_of_two
 from ..autotuner.config_spec import ConfigSpec
 from ..autotuner.config_spec import FlattenLoopSpec
 from ..autotuner.config_spec import L2GroupingSpec
@@ -39,7 +35,7 @@
 from .._compiler.inductor_lowering import CodegenState
 
 
-__all__ = ["Tile", "grid", "register_block_size", "register_reduction_dim", "tile"]
+__all__ = ["Tile", "grid", "tile"]
 Tile = TileIndexProxy
 
 
@@ -372,130 +368,3 @@ def _(state: CodegenState) -> ast.AST:
         state.tile_strategy.codegen_grid(state, block_ids)
         return expr_from_string("None")
     raise AssertionError(f"Expected loop type: {loop_type}")
-
-
-@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
-def register_block_size(min_or_max: int, max_or_none: int | None = None, /) -> int:
-    """
-    Explicitly register a block size that should be autotuned and can be used for
-    allocations and inside hl.tile(..., block_size=...).
-
-    This is useful if you have two loops where you want them to share a block size,
-    or if you need to allocate a kernel tensor before the hl.tile() loop.
-
-    The signature can be one of:
-        hl.register_block_size(max)
-        hl.register_block_size(min, max)
-
-    Where min and max are integers that control the range of block_sizes searched by
-    the autotuner. Max may be a symbolic shape, but min must be a constant integer.
-    """
-    raise exc.NotInsideKernel
-
-
-@_decorators.type_propagation(register_block_size)
-def _(
-    min_or_max: TypeInfo, max_or_none: TypeInfo | None = None, /, *, origin: Origin
-) -> TypeInfo:
-    from .._compiler.type_propagation import SymIntType
-
-    min_type, max_type = _normalize_begin_end(min_or_max, max_or_none, origin=origin)
-    min_proxy = _to_proxy(min_type)
-    max_proxy = _to_proxy(max_type)
-    if not isinstance(max_proxy, (int, torch.SymInt)):
-        raise exc.IncorrectTileUsage(
-            f"expected max to be an integer or size, got {max_proxy!s}"
-        )
-    if not isinstance(min_proxy, int):
-        raise exc.IncorrectTileUsage(
-            f"expected min to be an integer constant, got {min_proxy!s}"
-        )
-    env = CompileEnvironment.current()
-    result = TileIndexType.allocate(AutoSize(), origin)
-    loop_spec = env.config_spec.block_sizes.block_id_lookup(result.block_id)
-    loop_spec.min_size = assert_integer_power_of_two(max(1, min_proxy))
-    loop_spec.max_size = next_power_of_2(env.size_hint(max_proxy))
-    block_id = result.block_id
-    return SymIntType(origin, env.block_sizes[block_id].var)
-
-
-def _block_id_from_state(state: CodegenState) -> int:
-    """Extract the block_id from the current state for hl.register_block_size nodes."""
-    from .._compiler.type_propagation import SymIntType
-
-    env = CompileEnvironment.current()
-    if state.fx_node is not None:
-        val = state.fx_node.meta["val"]
-        assert isinstance(val, SymIntType)
-        block_id = env.get_block_id(val.value)
-        assert block_id is not None
-        return block_id
-    current_node = ExtendedAST.current()[-1]
-    type_info = current_node._type_info
-    assert isinstance(type_info, SymIntType)
-    block_id = env.get_block_id(type_info.value)
-    assert block_id is not None
-    return block_id
-
-
-@_decorators.codegen(register_block_size)
-def _(state: CodegenState) -> ast.AST:
-    env = CompileEnvironment.current()
-    block_size = env.config_spec.block_sizes.config_get(
-        state.config.block_sizes, _block_id_from_state(state)
-    )
-    assert block_size is not None
-    return expr_from_string(constant_repr(block_size))
-
-
-@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
-def register_reduction_dim(
-    size: int,
-) -> int:
-    """
-    Explicitly register a reduction dimension that should be used for reduction operations.
-
-    This is useful when you need to allocate a dimension for reduction that isn't
-    automatically inferred from a slice operation. The registered dimension can be
-    used for allocations and operations that require knowing the reduction size upfront.
-
-    :param size: An integer representing the reduction dimension size.
-    :return: A SymInt object representing the reduction dimension size.
-    """
-    raise exc.NotInsideKernel
-
-
-@_decorators.type_propagation(register_reduction_dim)
-def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
-    from .._compiler.compile_environment import CompileEnvironment
-    from .._compiler.type_propagation import SymIntType
-
-    try:
-        proxy_sizes = sizes.proxy()
-        if not isinstance(proxy_sizes, int | torch.SymInt):
-            raise NotImplementedError
-    except NotImplementedError:
-        raise exc.TypePropagationError(
-            UnknownType(
-                origin,
-                f"register_reduction_dim() expected int or list[int], got {sizes!s}",
-                chained_from=sizes,
-            )
-        ) from None
-
-    env = CompileEnvironment.current()
-
-    rdim = env.allocate_reduction_dimension(proxy_sizes)
-    return SymIntType(origin, rdim.var)
-
-
-@_decorators.codegen(register_reduction_dim)
-def _(state: CodegenState) -> ast.AST:
-    """Generate code for register_reduction_dim - return the size expression"""
-    from .._compiler.type_propagation import SymIntType
-
-    current_node = ExtendedAST.current()[-1]
-    type_info = current_node._type_info
-
-    assert isinstance(type_info, SymIntType)
-    return current_node.args[0]  # pyre-ignore[16]
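
The docstring removed here (and re-added in tunable_ops.py below) describes the two call signatures. As a hedged sketch of the shared-block-size use case it mentions (the kernel body and tensor names are illustrative assumptions, built on the @helion.kernel decorator and the hl.tile(..., block_size=...) API named in the docstring):

import torch
import helion
import helion.language as hl

@helion.kernel()
def double_then_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    (n,) = x.size()
    out = torch.empty_like(x)
    # One autotuned block size, searched over powers of two in [16, n],
    # shared by both loops below.
    bs = hl.register_block_size(16, n)
    for tile0 in hl.tile(n, block_size=bs):
        out[tile0] = x[tile0] * 2.0
    for tile1 in hl.tile(n, block_size=bs):
        out[tile1] = out[tile1] + y[tile1]
    return out

Because min must be a power of two (see assert_integer_power_of_two above), 16 is used here; the autotuner then picks one block size that both loops observe.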

helion/language/tunable_ops.py

Lines changed: 139 additions & 2 deletions
@@ -2,23 +2,160 @@
 
 from typing import TYPE_CHECKING
 
+import torch
 from torch._inductor.codegen.simd import constant_repr
+from torch._inductor.runtime.runtime_utils import next_power_of_2
 
 from .. import exc
+from .._compiler.ast_extension import ExtendedAST
 from .._compiler.ast_extension import expr_from_string
+from .._compiler.compile_environment import AutoSize
+from .._compiler.compile_environment import CompileEnvironment
+from .._compiler.type_propagation import TileIndexType
+from .._compiler.type_propagation import TypeInfo
+from .._compiler.type_propagation import UnknownType
+from .._compiler.type_propagation import _to_proxy
 from ..autotuner.config_fragment import ConfigSpecFragment
+from ..autotuner.config_fragment import assert_integer_power_of_two
 from ..autotuner.config_spec import VALID_KEYS
 from ..exc import NotInsideKernel
 from . import _decorators
+from .loops import _normalize_begin_end
 
 if TYPE_CHECKING:
     import ast
 
     from .._compiler.inductor_lowering import CodegenState
-    from .._compiler.type_propagation import TypeInfo
     from .._compiler.variable_origin import Origin
 
-__all__ = ["register_tunable"]
+__all__ = ["register_block_size", "register_reduction_dim", "register_tunable"]
+
+
+@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
+def register_block_size(min_or_max: int, max_or_none: int | None = None, /) -> int:
+    """
+    Explicitly register a block size that should be autotuned and can be used for
+    allocations and inside hl.tile(..., block_size=...).
+
+    This is useful if you have two loops where you want them to share a block size,
+    or if you need to allocate a kernel tensor before the hl.tile() loop.
+
+    The signature can be one of:
+        hl.register_block_size(max)
+        hl.register_block_size(min, max)
+
+    Where min and max are integers that control the range of block_sizes searched by
+    the autotuner. Max may be a symbolic shape, but min must be a constant integer.
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.type_propagation(register_block_size)
+def _(
+    min_or_max: TypeInfo, max_or_none: TypeInfo | None = None, /, *, origin: Origin
+) -> TypeInfo:
+    from .._compiler.type_propagation import SymIntType
+
+    min_type, max_type = _normalize_begin_end(min_or_max, max_or_none, origin=origin)
+    min_proxy = _to_proxy(min_type)
+    max_proxy = _to_proxy(max_type)
+    if not isinstance(max_proxy, (int, torch.SymInt)):
+        raise exc.IncorrectTileUsage(
+            f"expected max to be an integer or size, got {max_proxy!s}"
+        )
+    if not isinstance(min_proxy, int):
+        raise exc.IncorrectTileUsage(
+            f"expected min to be an integer constant, got {min_proxy!s}"
+        )
+    env = CompileEnvironment.current()
+    result = TileIndexType.allocate(AutoSize(), origin)
+    loop_spec = env.config_spec.block_sizes.block_id_lookup(result.block_id)
+    loop_spec.min_size = assert_integer_power_of_two(max(1, min_proxy))
+    loop_spec.max_size = next_power_of_2(env.size_hint(max_proxy))
+    block_id = result.block_id
+    return SymIntType(origin, env.block_sizes[block_id].var)
+
+
+def _block_id_from_state(state: CodegenState) -> int:
+    """Extract the block_id from the current state for hl.register_block_size nodes."""
+    from .._compiler.type_propagation import SymIntType
+
+    env = CompileEnvironment.current()
+    if state.fx_node is not None:
+        val = state.fx_node.meta["val"]
+        assert isinstance(val, SymIntType)
+        block_id = env.get_block_id(val.value)
+        assert block_id is not None
+        return block_id
+    current_node = ExtendedAST.current()[-1]
+    type_info = current_node._type_info
+    assert isinstance(type_info, SymIntType)
+    block_id = env.get_block_id(type_info.value)
+    assert block_id is not None
+    return block_id
+
+
+@_decorators.codegen(register_block_size)
+def _(state: CodegenState) -> ast.AST:
+    env = CompileEnvironment.current()
+    block_size = env.config_spec.block_sizes.config_get(
+        state.config.block_sizes, _block_id_from_state(state)
+    )
+    assert block_size is not None
+    return expr_from_string(constant_repr(block_size))
+
+
+@_decorators.api(is_device_only=False, cache_type=True, tiles_as_sizes=True)
+def register_reduction_dim(
+    size: int,
+) -> int:
+    """
+    Explicitly register a reduction dimension that should be used for reduction operations.
+
+    This is useful when you need to allocate a dimension for reduction that isn't
+    automatically inferred from a slice operation. The registered dimension can be
+    used for allocations and operations that require knowing the reduction size upfront.
+
+    :param size: An integer representing the reduction dimension size.
+    :return: A SymInt object representing the reduction dimension size.
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.type_propagation(register_reduction_dim)
+def _(sizes: TypeInfo, *, origin: Origin) -> TypeInfo:
+    from .._compiler.compile_environment import CompileEnvironment
+    from .._compiler.type_propagation import SymIntType
+
+    try:
+        proxy_sizes = sizes.proxy()
+        if not isinstance(proxy_sizes, int | torch.SymInt):
+            raise NotImplementedError
+    except NotImplementedError:
+        raise exc.TypePropagationError(
+            UnknownType(
+                origin,
+                f"register_reduction_dim() expected int or list[int], got {sizes!s}",
+                chained_from=sizes,
+            )
+        ) from None
+
+    env = CompileEnvironment.current()
+
+    rdim = env.allocate_reduction_dimension(proxy_sizes)
+    return SymIntType(origin, rdim.var)
+
+
+@_decorators.codegen(register_reduction_dim)
+def _(state: CodegenState) -> ast.AST:
+    """Generate code for register_reduction_dim - return the size expression"""
+    from .._compiler.type_propagation import SymIntType
+
+    current_node = ExtendedAST.current()[-1]
+    type_info = current_node._type_info
+
+    assert isinstance(type_info, SymIntType)
+    return current_node.args[0]  # pyre-ignore[16]
 
 
 @_decorators.api(is_device_only=False)
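
Read together, the two ops moved here enable the "allocate before you tile" pattern the docstrings describe. A hedged sketch of that pattern (the hl.zeros allocation helper and the accumulation idiom are assumptions for illustration; only register_block_size and register_reduction_dim are confirmed by this diff):

import torch
import helion
import helion.language as hl

@helion.kernel()
def row_sums(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m], dtype=torch.float32, device=x.device)
    block_m = hl.register_block_size(m)  # autotuned row-tile size
    rdim = hl.register_reduction_dim(n)  # reduction extent known upfront
    for tile_m in hl.tile(m, block_size=block_m):
        # The registered sizes let this accumulator be shaped before any
        # slice operation would have inferred the reduction dimension.
        acc = hl.zeros([tile_m, rdim], dtype=torch.float32)  # assumed helper
        acc = acc + x[tile_m, :]
        out[tile_m] = acc.sum(-1)
    return out

Since register_reduction_dim(n) returns a SymInt tied to n, the accumulator's last dimension lines up with the loaded slice by construction.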
