Skip to content

Commit 8668247

Browse files
committed
Add static_range
stack-info: PR: #235, branch: joydddd/stack/9
1 parent 0996865 commit 8668247

File tree

8 files changed

+240
-13
lines changed

8 files changed

+240
-13
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ Contains one entry per loop dimension, controlling the `flatten`
209209
parameter for `tl.range()` calls. `True` sets `flatten=True`,
210210
`False` sets `flatten=False`, and `None` omits the parameter.
211211

212+
* **static\_ranges** (`list[bool]`):
213+
Contains one entry per loop dimension, controlling whether to use
214+
`tl.static_range()` calls. `True` uses `tl.static_range()`, `False` uses `tl.range()`.
215+
212216
* **range\_warp\_specializes** (`list[bool | None]`):
213217
Contains one entry per loop dimension, controlling the `warp_specialize`
214218
parameter for `tl.range()` calls. `True` sets `warp_specialize=True`,

helion/_compiler/tile_strategy.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ def fn(self) -> DeviceFunction:
111111
assert fn is not None
112112
return fn
113113

114+
def get_range_fn_name(self, state: CodegenState, block_idx: int) -> str:
    """Return the name of the Triton range callable to emit for ``block_idx``.

    Looks up the ``static_ranges`` config entry for the block (passing
    ``None`` as the third ``config_get`` argument) and returns
    ``"tl.static_range"`` only when that entry is exactly ``True``;
    every other value yields ``"tl.range"``.
    """
    static_ranges_spec = CompileEnvironment.current().config_spec.static_ranges
    use_static = static_ranges_spec.config_get(
        state.config.static_ranges, block_idx, None
    )
    return "tl.static_range" if use_static is True else "tl.range"
122+
114123
def offset_var(self, block_idx: int) -> str:
115124
return self.offset_vars[block_idx]
116125

@@ -400,11 +409,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
400409
dtype = CompileEnvironment.current().triton_index_type()
401410
lid = self.new_var("lid")
402411
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0])
412+
range_fn = self.get_range_fn_name(state, self.block_ids[0])
403413
for_node = create(
404414
ast.For,
405415
target=create(ast.Name, id=lid, ctx=ast.Store()),
406416
iter=expr_from_string(
407-
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
417+
f"{range_fn}(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
408418
),
409419
body=(
410420
body := [
@@ -610,11 +620,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
610620
)
611621

612622
range_extra = self.get_tl_range_kwargs(state, block_idx)
623+
range_fn = self.get_range_fn_name(state, self.block_ids[0])
613624
for_node = create(
614625
ast.For,
615626
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
616627
iter=expr_from_string(
617-
f"tl.range(begin, end, {block_size_var}{range_extra})",
628+
f"{range_fn}(begin, end, {block_size_var}{range_extra})",
618629
begin=self._to_ast(begin, to_dtype=dtype),
619630
end=self._to_ast(end, to_dtype=dtype),
620631
),

helion/autotuner/block_id_sequence.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def block_id_lookup(self, block_id: int) -> _BlockIdItemT:
110110
"""Return the config item associated with the block_id."""
111111
return self._data[self._block_id_to_index[block_id]]
112112

113+
def valid_block_ids(self) -> list[int]:
    """Return a new list holding every block_id that has an index entry."""
    # Iterating the mapping yields its keys; the copy keeps callers from
    # aliasing the internal lookup table.
    return [*self._block_id_to_index]
116+
113117
def disable_block_id(self, block_id: int) -> None:
114118
"""Remove configuration choice for the given block_id."""
115119
self._data = [x for x in self._data if block_id not in x.block_ids]

helion/autotuner/config_spec.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"range_num_stages",
4242
"range_multi_buffers",
4343
"range_flattens",
44+
"static_ranges",
4445
"num_warps",
4546
"num_stages",
4647
"use_yz_grid",
@@ -81,6 +82,9 @@ class ConfigSpec:
8182
range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field(
8283
default_factory=BlockIdSequence
8384
)
85+
static_ranges: BlockIdSequence[StaticRangeSpec] = dataclasses.field(
86+
default_factory=BlockIdSequence
87+
)
8488
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
8589
default_factory=dict
8690
)
@@ -95,6 +99,7 @@ def _remove_duplicates(self) -> None:
9599
self.range_num_stages._remove_duplicates()
96100
self.range_multi_buffers._remove_duplicates()
97101
self.range_flattens._remove_duplicates()
102+
self.static_ranges._remove_duplicates()
98103

99104
def normalize(self, config: helion.Config | dict[str, object]) -> None:
100105
"""Normalize the config to match the block_sizes and validate the config."""
@@ -113,6 +118,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
113118
"range_num_stage",
114119
"range_multi_buffer",
115120
"range_flatten",
121+
"static_range",
116122
):
117123
if name in config:
118124
names = f"{name}s"
@@ -131,11 +137,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
131137
("range_num_stages", self.range_num_stages, True),
132138
("range_multi_buffers", self.range_multi_buffers, True),
133139
("range_flattens", self.range_flattens, True),
140+
("static_ranges", self.static_ranges, True),
134141
]:
135142
config[name] = mapping._normalize(
136143
name, config.get(name, ()), flatten=flatten
137144
)
138145

146+
for block_id in self.static_ranges.valid_block_ids():
147+
use_static_range = self.static_ranges.config_get(
148+
config.get("static_ranges", ()), # pyre-ignore[6]
149+
block_id,
150+
)
151+
152+
if use_static_range:
153+
for name, mapping in (
154+
("range_unroll_factors", self.range_unroll_factors),
155+
("range_warp_specializes", self.range_warp_specialize),
156+
("range_num_stages", self.range_num_stages),
157+
("range_multi_buffers", self.range_multi_buffers),
158+
("range_flattens", self.range_flattens),
159+
):
160+
if config[name]: # The config is non empty
161+
# pyre-ignore[16]
162+
config[name][mapping.block_id_to_index(block_id)] = (
163+
mapping.block_id_lookup(block_id)._fill_missing()
164+
)
165+
139166
for name in (
140167
"loop_orders",
141168
"l2_groupings",
@@ -146,6 +173,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
146173
"range_num_stages",
147174
"range_multi_buffers",
148175
"range_flattens",
176+
"static_ranges",
149177
):
150178
if not config[name]:
151179
config.pop(name)
@@ -180,6 +208,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
180208
"range_num_stages": self.range_num_stages._flat_config(self, fn),
181209
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
182210
"range_flattens": self.range_flattens._flat_config(self, fn),
211+
"static_ranges": self.static_ranges._flat_config(self, fn),
183212
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
184213
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
185214
"indexing": fn(
@@ -211,6 +240,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
211240
"range_num_stages",
212241
"range_multi_buffers",
213242
"range_flattens",
243+
"static_ranges",
214244
):
215245
if not config[name]:
216246
config.pop(name)
@@ -399,6 +429,36 @@ class RangeFlattenSpec(_OptionalBoolSpec):
399429
pass
400430

401431

432+
class StaticRangeSpec(_BlockIdItem):
    """Config item selecting ``tl.static_range`` vs ``tl.range`` for one loop.

    ``is_static`` records whether the loop was registered as having static
    parameters; a config value of ``True`` is only accepted in that case.
    """

    def __init__(
        self,
        block_id: int,
        is_static: bool,
    ) -> None:
        """Register the spec for a single ``block_id``."""
        super().__init__([block_id])
        self.is_static = is_static

    def _fragment(self, base: ConfigSpec) -> ConfigSpecFragment:
        """Return the tunable fragment for this loop."""
        # Only enable tl.static_range when loop parameters are static
        if self.is_static:
            return BooleanFragment()
        # Non-static loops get a single-choice fragment pinned to False.
        return EnumFragment((False,))

    def _normalize(self, name: str, value: object) -> bool:
        """Validate a user-supplied value and return it unchanged.

        Raises:
            InvalidConfig: if ``value`` is not a bool, or if it is ``True``
                while this loop is not static.
        """
        if not isinstance(value, bool):
            raise InvalidConfig(f"{name} must be a boolean, got {value!r}")
        if value is True and self.is_static is False:
            raise InvalidConfig(
                f"Got {name}=True for non-static loop #{self.block_id}\n Did you forget to call hl.specialize() on the loop dim? "
            )
        return value

    def _fill_missing(self) -> bool:
        """Provide a value when not provided by the user."""
        return False
460+
461+
402462
def _product(seq: Sequence[int]) -> int:
403463
"""Return the product of the elements in the sequence."""
404464
return functools.reduce(operator.mul, seq, 1)

helion/language/loops.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ..autotuner.config_spec import RangeNumStagesSpec
3434
from ..autotuner.config_spec import RangeUnrollFactorSpec
3535
from ..autotuner.config_spec import RangeWarpSpecializeSpec
36+
from ..autotuner.config_spec import StaticRangeSpec
3637
from . import _decorators
3738
from helion.language.tile_proxy import Tile
3839

@@ -151,6 +152,17 @@ def _check_matching(a: object, b: object) -> None:
151152
)
152153

153154

155+
def _is_specialized_int(a: object) -> bool:
156+
import sympy
157+
158+
"""Check if the arg is specialized."""
159+
if isinstance(a, int):
160+
return True
161+
if isinstance(a, torch.SymInt):
162+
return isinstance(a._sympy_(), sympy.Integer)
163+
return False
164+
165+
154166
def _normalize_begin_end(
155167
begin_or_end: TypeInfo,
156168
end_or_none: TypeInfo | None,
@@ -225,6 +237,10 @@ def _(
225237
[x.block_id for x in results],
226238
is_tile=True,
227239
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
240+
is_static=all(
241+
_is_specialized_int(x) or x is None
242+
for x in (*proxy_begin, *proxy_end, *proxy_block_size)
243+
),
228244
)
229245
if unpack:
230246
(result,) = results
@@ -234,7 +250,11 @@ def _(
234250

235251

236252
def _add_config_choices(
237-
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False
253+
block_ids: list[int],
254+
*,
255+
is_tile: bool = False,
256+
has_begin: bool = False,
257+
is_static: bool = False,
238258
) -> None:
239259
config_spec = CompileEnvironment.current().config_spec
240260

@@ -253,6 +273,7 @@ def _add_config_choices(
253273
else:
254274
params = inspect.signature(triton.language.range).parameters
255275
for block_id in block_ids:
276+
config_spec.static_ranges.append(StaticRangeSpec(block_id, is_static))
256277
if "loop_unroll_factor" in params:
257278
config_spec.range_unroll_factors.append(
258279
RangeUnrollFactorSpec([block_id])
@@ -419,6 +440,10 @@ def _(
419440
[x.block_id for x in results],
420441
is_tile=False,
421442
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
443+
is_static=all(
444+
_is_specialized_int(x) or x is None
445+
for x in (*proxy_begin, *proxy_end, *proxy_step)
446+
),
422447
)
423448
if unpack:
424449
(result,) = results

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
range_num_stages: list[int] | None = None,
3131
range_multi_buffers: list[bool | None] | None = None,
3232
range_flattens: list[bool | None] | None = None,
33+
static_ranges: list[bool] | None = None,
3334
num_warps: int | None = None,
3435
num_stages: int | None = None,
3536
use_yz_grid: bool | None = None,
@@ -50,6 +51,7 @@ def __init__(
5051
range_num_stages: Number of stages for tl.range calls.
5152
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
5253
range_flattens: Controls flatten parameter for tl.range calls.
54+
static_ranges: Whether to use tl.static_range instead of tl.range.
5355
num_warps: Number of warps per block.
5456
num_stages: Number of stages for software pipelining.
5557
use_yz_grid: Whether to use yz grid dimensions.
@@ -68,6 +70,7 @@ def __init__(
6870
"range_num_stages": range_num_stages,
6971
"range_multi_buffers": range_multi_buffers,
7072
"range_flattens": range_flattens,
73+
"static_ranges": static_ranges,
7174
"num_warps": num_warps,
7275
"num_stages": num_stages,
7376
"indexing": indexing,
@@ -173,6 +176,10 @@ def range_multi_buffers(self) -> list[bool | None]:
173176
def range_flattens(self) -> list[bool | None]:
174177
return cast("list[bool | None]", self.config.get("range_flattens", []))
175178

179+
@property
def static_ranges(self) -> list[bool]:
    """Return the ``static_ranges`` config entry, or an empty list when unset."""
    flags = self.config.get("static_ranges", [])
    return cast("list[bool]", flags)
182+
176183
@property
177184
def indexing(self) -> IndexingLiteral:
178185
return self.config.get("indexing", "pointer") # type: ignore

test/test_autotuner.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ def test_config_fragment0(self):
4646
self.assertExpectedInline(
4747
"\n".join(map(repr, configs)),
4848
"""\
49-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
50-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=1, num_stages=7, indexing='tensor_descriptor')
51-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=2, num_stages=8, indexing='tensor_descriptor')
52-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], num_warps=32, num_stages=2, indexing='tensor_descriptor')
53-
helion.Config(block_sizes=[16, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='pointer')
54-
helion.Config(block_sizes=[256, 128, 16], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], num_warps=8, num_stages=4, indexing='tensor_descriptor')
55-
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[False], num_warps=1, num_stages=8, indexing='tensor_descriptor')
56-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[False], num_warps=4, num_stages=4, indexing='tensor_descriptor')
57-
helion.Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[True], num_warps=16, num_stages=4, indexing='block_ptr')
58-
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], num_warps=4, num_stages=1, indexing='tensor_descriptor')""",
49+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], static_ranges=[False], num_warps=4, num_stages=3, indexing='pointer')
50+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], static_ranges=[True], num_warps=1, num_stages=7, indexing='tensor_descriptor')
51+
helion.Config(block_sizes=[64, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[1], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], static_ranges=[False], num_warps=32, num_stages=8, indexing='tensor_descriptor')
52+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], static_ranges=[False], num_warps=16, num_stages=1, indexing='pointer')
53+
helion.Config(block_sizes=[16, 128, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], static_ranges=[True], num_warps=16, num_stages=7, indexing='pointer')
54+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], static_ranges=[True], num_warps=2, num_stages=3, indexing='tensor_descriptor')
55+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_num_stages=[0], range_multi_buffers=[False], range_flattens=[None], static_ranges=[True], num_warps=16, num_stages=3, indexing='block_ptr')
56+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[2], range_num_stages=[3], range_multi_buffers=[False], range_flattens=[False], static_ranges=[False], num_warps=32, num_stages=5, indexing='pointer')
57+
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[0], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], static_ranges=[False], num_warps=8, num_stages=6, indexing='block_ptr')
58+
helion.Config(block_sizes=[16, 16, 32], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[None], static_ranges=[True], num_warps=8, num_stages=5, indexing='tensor_descriptor')""",
5959
)
6060

6161
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)

0 commit comments

Comments
 (0)