Skip to content

Commit 2e7c8bb

Browse files
committed
Add static_range
stack-info: PR: #235, branch: joydddd/stack/9
1 parent 0996865 commit 2e7c8bb

17 files changed

+380
-105
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ parameter for `tl.range()` calls. `True` sets `warp_specialize=True`,
216216
Only available on CUDA devices with Blackwell or newer architectures
217217
when `allow_warp_specialize` setting is enabled.
218218

219+
* **static\_ranges** (`list[bool]`):
220+
Contains one entry per loop dimension with static bounds, controlling whether to use
221+
`tl.static_range()` calls. `True` generates `tl.static_range()` and ignores range_* configs for that loop. `False` generates `tl.range()`.
222+
219223
* **reduction\_loops** (`list[int | None]`):
220224
Contains one entry per reduction dimension (see
221225
`examples/softmax.py`). Using `None` triggers a persistent reduction,

helion/_compiler/reduction_strategy.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,16 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
253253
)
254254
)
255255

256-
range_extra = self.get_tl_range_kwargs(state, self.block_index)
257256
for_node = create(
258257
ast.For,
259258
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
260259
iter=expr_from_string(
261-
f"tl.range(0, ({state.sympy_expr(numel)}), {block_size_var}{range_extra})"
260+
self.get_range_call_str(
261+
state,
262+
[self.block_index],
263+
end=state.sympy_expr(numel),
264+
step=block_size_var,
265+
),
262266
),
263267
body=body,
264268
orelse=[],

helion/_compiler/tile_strategy.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def mask_var(self, block_idx: int) -> str | None:
123123
def block_size_var(self, block_idx: int) -> str | None:
124124
return self.fn.block_size_var_cache.get((block_idx,))
125125

126-
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
126+
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> list[str]:
127127
"""Get the range_extra string for loop unroll factor and num_stages based on config."""
128128
env = CompileEnvironment.current()
129129
kwargs = []
@@ -157,10 +157,37 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
157157
)
158158
if range_flatten is not None:
159159
kwargs.append(f"flatten={range_flatten}")
160+
return kwargs
160161

161-
if kwargs:
162-
return f", {', '.join(kwargs)}"
163-
return ""
162+
def get_range_call_str(
163+
self,
164+
state: CodegenState,
165+
block_ids: list[int],
166+
*,
167+
begin: str | None = None,
168+
end: str,
169+
step: str | None = None,
170+
) -> str:
171+
env = CompileEnvironment.current()
172+
use_static_range = all(
173+
env.config_spec.static_ranges.config_get(
174+
state.config.static_ranges, block_idx, None
175+
)
176+
is True
177+
for block_idx in block_ids
178+
)
179+
180+
range_args = []
181+
if begin is not None:
182+
range_args.append(begin)
183+
range_args.append(end)
184+
if step is not None:
185+
range_args.append(f"step={step}")
186+
187+
if use_static_range:
188+
return f"tl.static_range({', '.join(range_args)})"
189+
range_kwargs = self.get_tl_range_kwargs(state, block_ids[0])
190+
return f"tl.range({', '.join(range_args + range_kwargs)})"
164191

165192
def user_size(self, block_index: int) -> sympy.Expr:
166193
raise NotImplementedError
@@ -399,12 +426,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
399426
)
400427
dtype = CompileEnvironment.current().triton_index_type()
401428
lid = self.new_var("lid")
402-
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0])
429+
end_var = f"tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var})"
403430
for_node = create(
404431
ast.For,
405432
target=create(ast.Name, id=lid, ctx=ast.Store()),
406433
iter=expr_from_string(
407-
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
434+
self.get_range_call_str(state, self.block_ids, end=end_var)
408435
),
409436
body=(
410437
body := [
@@ -609,12 +636,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
609636
end_expr=self._fold_tile_end_op(state, proxy_end, block_size),
610637
)
611638

612-
range_extra = self.get_tl_range_kwargs(state, block_idx)
613639
for_node = create(
614640
ast.For,
615641
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
616642
iter=expr_from_string(
617-
f"tl.range(begin, end, {block_size_var}{range_extra})",
643+
self.get_range_call_str(
644+
state,
645+
[block_idx],
646+
begin="begin",
647+
end="end",
648+
step=block_size_var,
649+
),
618650
begin=self._to_ast(begin, to_dtype=dtype),
619651
end=self._to_ast(end, to_dtype=dtype),
620652
),

helion/autotuner/block_id_sequence.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def block_id_lookup(self, block_id: int) -> _BlockIdItemT:
110110
"""Return the index of the block_id in the config."""
111111
return self._data[self._block_id_to_index[block_id]]
112112

113+
def valid_block_ids(self) -> list[int]:
114+
"""Return the list of valid block_ids."""
115+
return list(self._block_id_to_index.keys())
116+
113117
def disable_block_id(self, block_id: int) -> None:
114118
"""Remove configuration choice for the given block_id."""
115119
self._data = [x for x in self._data if block_id not in x.block_ids]
@@ -132,6 +136,24 @@ def _flat_config(
132136
"""Map a flattened version of the config using the given function."""
133137
return [spec._flat_config(base, fn) for spec in self._data]
134138

139+
def _reset_to_default(
140+
self, name: str, values: object, *, block_ids: list[int] | None = None
141+
) -> list[object]:
142+
"""Set the config values to the default values. If block_ids is provided, only set those values."""
143+
if not values:
144+
return []
145+
assert isinstance(values, list)
146+
assert len(values) == len(self)
147+
148+
if block_ids is None:
149+
block_ids = self.valid_block_ids()
150+
for block_id in block_ids:
151+
if block_id not in self._block_id_to_index:
152+
continue
153+
index = self._block_id_to_index[block_id]
154+
values[index] = self._data[index]._fill_missing()
155+
return values
156+
135157
def _normalize(
136158
self, name: str, values: object, *, flatten: bool = False
137159
) -> list[object]:

helion/autotuner/config_spec.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"range_num_stages",
4242
"range_multi_buffers",
4343
"range_flattens",
44+
"static_ranges",
4445
"num_warps",
4546
"num_stages",
4647
"use_yz_grid",
@@ -81,6 +82,9 @@ class ConfigSpec:
8182
range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field(
8283
default_factory=BlockIdSequence
8384
)
85+
static_ranges: BlockIdSequence[StaticRangeSpec] = dataclasses.field(
86+
default_factory=BlockIdSequence
87+
)
8488
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
8589
default_factory=dict
8690
)
@@ -95,6 +99,7 @@ def _remove_duplicates(self) -> None:
9599
self.range_num_stages._remove_duplicates()
96100
self.range_multi_buffers._remove_duplicates()
97101
self.range_flattens._remove_duplicates()
102+
self.static_ranges._remove_duplicates()
98103

99104
def normalize(self, config: helion.Config | dict[str, object]) -> None:
100105
"""Normalize the config to match the block_sizes and validate the config."""
@@ -113,6 +118,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
113118
"range_num_stage",
114119
"range_multi_buffer",
115120
"range_flatten",
121+
"static_range",
116122
):
117123
if name in config:
118124
names = f"{name}s"
@@ -131,11 +137,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
131137
("range_num_stages", self.range_num_stages, True),
132138
("range_multi_buffers", self.range_multi_buffers, True),
133139
("range_flattens", self.range_flattens, True),
140+
("static_ranges", self.static_ranges, True),
134141
]:
135142
config[name] = mapping._normalize(
136143
name, config.get(name, ()), flatten=flatten
137144
)
138145

146+
static_range_block_ids = []
147+
for block_id in self.static_ranges.valid_block_ids():
148+
use_static_range = self.static_ranges.config_get(
149+
config.get("static_ranges", ()), # pyre-ignore[6]
150+
block_id,
151+
)
152+
if use_static_range:
153+
static_range_block_ids.append(block_id)
154+
155+
for name, mapping in (
156+
("range_unroll_factors", self.range_unroll_factors),
157+
("range_warp_specializes", self.range_warp_specialize),
158+
("range_num_stages", self.range_num_stages),
159+
("range_multi_buffers", self.range_multi_buffers),
160+
("range_flattens", self.range_flattens),
161+
):
162+
config[name] = mapping._reset_to_default(
163+
name, config.get(name, ()), block_ids=static_range_block_ids
164+
)
165+
139166
for name in (
140167
"loop_orders",
141168
"l2_groupings",
@@ -146,6 +173,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
146173
"range_num_stages",
147174
"range_multi_buffers",
148175
"range_flattens",
176+
"static_ranges",
149177
):
150178
if not config[name]:
151179
config.pop(name)
@@ -180,6 +208,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
180208
"range_num_stages": self.range_num_stages._flat_config(self, fn),
181209
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
182210
"range_flattens": self.range_flattens._flat_config(self, fn),
211+
"static_ranges": self.static_ranges._flat_config(self, fn),
183212
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
184213
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
185214
"indexing": fn(
@@ -211,6 +240,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
211240
"range_num_stages",
212241
"range_multi_buffers",
213242
"range_flattens",
243+
"static_ranges",
214244
):
215245
if not config[name]:
216246
config.pop(name)
@@ -399,6 +429,20 @@ class RangeFlattenSpec(_OptionalBoolSpec):
399429
pass
400430

401431

432+
class StaticRangeSpec(_BlockIdItem):
433+
def _fragment(self, base: ConfigSpec) -> BooleanFragment:
434+
return BooleanFragment()
435+
436+
def _normalize(self, name: str, value: object) -> bool:
437+
if not isinstance(value, bool):
438+
raise InvalidConfig(f"{name} must be a boolean, got {value!r}")
439+
return value
440+
441+
def _fill_missing(self) -> bool:
442+
"""Provide a value when not provided by the user."""
443+
return False
444+
445+
402446
def _product(seq: Sequence[int]) -> int:
403447
"""Return the product of the elements in the sequence."""
404448
return functools.reduce(operator.mul, seq, 1)

helion/language/loops.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ..autotuner.config_spec import RangeNumStagesSpec
3434
from ..autotuner.config_spec import RangeUnrollFactorSpec
3535
from ..autotuner.config_spec import RangeWarpSpecializeSpec
36+
from ..autotuner.config_spec import StaticRangeSpec
3637
from . import _decorators
3738
from helion.language.tile_proxy import Tile
3839

@@ -151,6 +152,23 @@ def _check_matching(a: object, b: object) -> None:
151152
)
152153

153154

155+
def _is_constexpr_int(a: object) -> bool:
156+
"""Check if the arg is specialized."""
157+
return isinstance(a, int)
158+
# TODO(joydddd): render SymInt backed by Int as constexpr.
159+
# Now the specialized constexpr is assigned to a dynamic variable first
160+
# and then used as a variable. However args to static_range must be constexpr.
161+
# e.g.
162+
# hl.specialize(x.size(0))
163+
# for i in hl.grid(x.size(0))
164+
# ->
165+
# symbol_0 = 64
166+
# for i in tl.static_range(symbol_0):
167+
#
168+
# if isinstance(a, torch.SymInt):
169+
# return isinstance(a._sympy_(), sympy.Integer)
170+
171+
154172
def _normalize_begin_end(
155173
begin_or_end: TypeInfo,
156174
end_or_none: TypeInfo | None,
@@ -225,6 +243,10 @@ def _(
225243
[x.block_id for x in results],
226244
is_tile=True,
227245
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
246+
is_static=all(
247+
_is_constexpr_int(x) or x is None
248+
for x in (*proxy_begin, *proxy_end, *proxy_block_size)
249+
),
228250
)
229251
if unpack:
230252
(result,) = results
@@ -234,7 +256,11 @@ def _(
234256

235257

236258
def _add_config_choices(
237-
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False
259+
block_ids: list[int],
260+
*,
261+
is_tile: bool = False,
262+
has_begin: bool = False,
263+
is_static: bool = False,
238264
) -> None:
239265
config_spec = CompileEnvironment.current().config_spec
240266

@@ -253,6 +279,8 @@ def _add_config_choices(
253279
else:
254280
params = inspect.signature(triton.language.range).parameters
255281
for block_id in block_ids:
282+
if is_static:
283+
config_spec.static_ranges.append(StaticRangeSpec([block_id]))
256284
if "loop_unroll_factor" in params:
257285
config_spec.range_unroll_factors.append(
258286
RangeUnrollFactorSpec([block_id])
@@ -419,6 +447,10 @@ def _(
419447
[x.block_id for x in results],
420448
is_tile=False,
421449
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
450+
is_static=all(
451+
_is_constexpr_int(x) or x is None
452+
for x in (*proxy_begin, *proxy_end, *proxy_step)
453+
),
422454
)
423455
if unpack:
424456
(result,) = results

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
range_num_stages: list[int] | None = None,
3131
range_multi_buffers: list[bool | None] | None = None,
3232
range_flattens: list[bool | None] | None = None,
33+
static_ranges: list[bool] | None = None,
3334
num_warps: int | None = None,
3435
num_stages: int | None = None,
3536
use_yz_grid: bool | None = None,
@@ -50,6 +51,7 @@ def __init__(
5051
range_num_stages: Number of stages for tl.range calls.
5152
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
5253
range_flattens: Controls flatten parameter for tl.range calls.
54+
static_ranges: Whether to use tl.static_range instead tl.range.
5355
num_warps: Number of warps per block.
5456
num_stages: Number of stages for software pipelining.
5557
use_yz_grid: Whether to use yz grid dimensions.
@@ -68,6 +70,7 @@ def __init__(
6870
"range_num_stages": range_num_stages,
6971
"range_multi_buffers": range_multi_buffers,
7072
"range_flattens": range_flattens,
73+
"static_ranges": static_ranges,
7174
"num_warps": num_warps,
7275
"num_stages": num_stages,
7376
"indexing": indexing,
@@ -173,6 +176,10 @@ def range_multi_buffers(self) -> list[bool | None]:
173176
def range_flattens(self) -> list[bool | None]:
174177
return cast("list[bool | None]", self.config.get("range_flattens", []))
175178

179+
@property
180+
def static_ranges(self) -> list[bool]:
181+
return cast("list[bool]", self.config.get("static_ranges", []))
182+
176183
@property
177184
def indexing(self) -> IndexingLiteral:
178185
return self.config.get("indexing", "pointer") # type: ignore

0 commit comments

Comments
 (0)