Skip to content

Commit 1dd3759

Browse files
committed
Add static_range
stack-info: PR: #235, branch: joydddd/stack/9
1 parent 902741b commit 1dd3759

19 files changed

+386
-121
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ parameter for `tl.range()` calls. `True` sets `warp_specialize=True`,
216216
Only available on CUDA devices with Blackwell or newer architectures
217217
when `allow_warp_specialize` setting is enabled.
218218

219+
* **static\_ranges** (`list[bool]`):
220+
Contains one entry per loop dimension with static bounds, controlling whether to use
221+
`tl.static_range()` calls. `True` generates `tl.static_range()` and ignores range_* configs for that loop. `False` generates `tl.range()`.
222+
219223
* **reduction\_loops** (`list[int | None]`):
220224
Contains one entry per reduction dimension (see
221225
`examples/softmax.py`). Using `None` triggers a persistent reduction,

helion/_compiler/reduction_strategy.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
253253
)
254254
)
255255

256-
range_extra = self.get_tl_range_kwargs(state, self.block_index)
257256
for_node = create(
258257
ast.For,
259258
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
260259
iter=expr_from_string(
261-
f"tl.range(0, ({state.sympy_expr(numel)}), {block_size_var}{range_extra})"
260+
self.get_range_call_str(
261+
state,
262+
[self.block_index],
263+
begin="0",
264+
end=state.sympy_expr(numel),
265+
step=block_size_var,
266+
),
262267
),
263268
body=body,
264269
orelse=[],

helion/_compiler/tile_strategy.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def mask_var(self, block_idx: int) -> str | None:
125125
def block_size_var(self, block_idx: int) -> str | None:
126126
return self.fn.block_size_var_cache.get((block_idx,))
127127

128-
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
128+
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> list[str]:
129129
"""Get the extra tl.range keyword arguments (e.g. loop unroll factor, num_stages, flatten) as a list of `name=value` strings based on config."""
130130
env = CompileEnvironment.current()
131131
kwargs = []
@@ -159,10 +159,37 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
159159
)
160160
if range_flatten is not None:
161161
kwargs.append(f"flatten={range_flatten}")
162+
return kwargs
162163

163-
if kwargs:
164-
return f", {', '.join(kwargs)}"
165-
return ""
164+
def get_range_call_str(
165+
self,
166+
state: CodegenState,
167+
block_ids: list[int],
168+
*,
169+
begin: str | None = None,
170+
end: str,
171+
step: str | None = None,
172+
) -> str:
173+
env = CompileEnvironment.current()
174+
use_static_range = all(
175+
env.config_spec.static_ranges.config_get(
176+
state.config.static_ranges, block_idx, None
177+
)
178+
is True
179+
for block_idx in block_ids
180+
)
181+
182+
range_args = []
183+
if begin is not None:
184+
range_args.append(begin)
185+
range_args.append(end)
186+
if step is not None:
187+
range_args.append(f"step={step}")
188+
189+
if use_static_range:
190+
return f"tl.static_range({', '.join(range_args)})"
191+
range_kwargs = self.get_tl_range_kwargs(state, block_ids[0])
192+
return f"tl.range({', '.join(range_args + range_kwargs)})"
166193

167194
def user_size(self, block_index: int) -> sympy.Expr:
168195
raise NotImplementedError
@@ -407,12 +434,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
407434
)
408435
dtype = CompileEnvironment.current().triton_index_type()
409436
lid = self.new_var("lid")
410-
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0])
437+
end_var = f"tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var})"
411438
for_node = create(
412439
ast.For,
413440
target=create(ast.Name, id=lid, ctx=ast.Store()),
414441
iter=expr_from_string(
415-
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
442+
self.get_range_call_str(state, self.block_ids, end=end_var)
416443
),
417444
body=(
418445
body := [
@@ -624,12 +651,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
624651
end_expr=self._fold_tile_end_op(state, proxy_end, block_size),
625652
)
626653

627-
range_extra = self.get_tl_range_kwargs(state, block_idx)
628654
for_node = create(
629655
ast.For,
630656
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
631657
iter=expr_from_string(
632-
f"tl.range(begin, end, {block_size_var}{range_extra})",
658+
self.get_range_call_str(
659+
state,
660+
[block_idx],
661+
begin="begin",
662+
end="end",
663+
step=block_size_var,
664+
),
633665
begin=self._to_ast(begin, to_dtype=dtype),
634666
end=self._to_ast(end, to_dtype=dtype),
635667
),

helion/autotuner/block_id_sequence.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ def block_id_lookup(self, block_id: int) -> _BlockIdItemT:
110110
"""Return the config item associated with the given block_id."""
111111
return self._data[self._block_id_to_index[block_id]]
112112

113+
def valid_block_ids(self) -> list[int]:
114+
"""Return the list of valid block_ids."""
115+
return list(self._block_id_to_index.keys())
116+
113117
def disable_block_id(self, block_id: int) -> None:
114118
"""Remove configuration choice for the given block_id."""
115119
self._data = [x for x in self._data if block_id not in x.block_ids]
@@ -132,6 +136,24 @@ def _flat_config(
132136
"""Map a flattened version of the config using the given function."""
133137
return [spec._flat_config(base, fn) for spec in self._data]
134138

139+
def _reset_config_to_default(
140+
self, name: str, values: object, *, block_ids: list[int] | None = None
141+
) -> list[object]:
142+
"""Set the config values to the default values. If block_ids is provided, only set those values."""
143+
if not values:
144+
return []
145+
assert isinstance(values, list)
146+
assert len(values) == len(self)
147+
148+
if block_ids is None:
149+
block_ids = self.valid_block_ids()
150+
for block_id in block_ids:
151+
if block_id not in self._block_id_to_index:
152+
continue
153+
index = self._block_id_to_index[block_id]
154+
values[index] = self._data[index]._fill_missing()
155+
return values
156+
135157
def _normalize(
136158
self, name: str, values: object, *, flatten: bool = False
137159
) -> list[object]:

helion/autotuner/config_spec.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"range_num_stages",
4545
"range_multi_buffers",
4646
"range_flattens",
47+
"static_ranges",
4748
"num_warps",
4849
"num_stages",
4950
"pid_type",
@@ -85,6 +86,9 @@ class ConfigSpec:
8586
range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field(
8687
default_factory=BlockIdSequence
8788
)
89+
static_ranges: BlockIdSequence[StaticRangeSpec] = dataclasses.field(
90+
default_factory=BlockIdSequence
91+
)
8892
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
8993
default_factory=dict
9094
)
@@ -109,6 +113,7 @@ def _remove_duplicates(self) -> None:
109113
self.range_num_stages._remove_duplicates()
110114
self.range_multi_buffers._remove_duplicates()
111115
self.range_flattens._remove_duplicates()
116+
self.static_ranges._remove_duplicates()
112117

113118
def disallow_pid_type(self, pid_type: PidTypeLiteral) -> None:
114119
"""Disallow a pid_type from being used in the config."""
@@ -135,6 +140,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
135140
"range_num_stage",
136141
"range_multi_buffer",
137142
"range_flatten",
143+
"static_range",
138144
):
139145
if name in config:
140146
names = f"{name}s"
@@ -153,11 +159,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
153159
("range_num_stages", self.range_num_stages, True),
154160
("range_multi_buffers", self.range_multi_buffers, True),
155161
("range_flattens", self.range_flattens, True),
162+
("static_ranges", self.static_ranges, True),
156163
]:
157164
config[name] = mapping._normalize(
158165
name, config.get(name, ()), flatten=flatten
159166
)
160167

168+
static_range_block_ids = []
169+
for block_id in self.static_ranges.valid_block_ids():
170+
use_static_range = self.static_ranges.config_get(
171+
config.get("static_ranges", ()), # pyre-ignore[6]
172+
block_id,
173+
)
174+
if use_static_range:
175+
static_range_block_ids.append(block_id)
176+
177+
for name, mapping in (
178+
("range_unroll_factors", self.range_unroll_factors),
179+
("range_warp_specializes", self.range_warp_specialize),
180+
("range_num_stages", self.range_num_stages),
181+
("range_multi_buffers", self.range_multi_buffers),
182+
("range_flattens", self.range_flattens),
183+
):
184+
config[name] = mapping._reset_config_to_default(
185+
name, config.get(name, ()), block_ids=static_range_block_ids
186+
)
187+
161188
for name in (
162189
"loop_orders",
163190
"l2_groupings",
@@ -168,6 +195,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
168195
"range_num_stages",
169196
"range_multi_buffers",
170197
"range_flattens",
198+
"static_ranges",
171199
):
172200
if not config[name]:
173201
config.pop(name)
@@ -209,6 +237,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
209237
"range_num_stages": self.range_num_stages._flat_config(self, fn),
210238
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
211239
"range_flattens": self.range_flattens._flat_config(self, fn),
240+
"static_ranges": self.static_ranges._flat_config(self, fn),
212241
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
213242
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
214243
"indexing": fn(EnumFragment(self._valid_indexing_types())),
@@ -228,6 +257,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
228257
"range_num_stages",
229258
"range_multi_buffers",
230259
"range_flattens",
260+
"static_ranges",
231261
):
232262
if not config[name]:
233263
config.pop(name)
@@ -416,6 +446,20 @@ class RangeFlattenSpec(_OptionalBoolSpec):
416446
pass
417447

418448

449+
class StaticRangeSpec(_BlockIdItem):
450+
def _fragment(self, base: ConfigSpec) -> BooleanFragment:
451+
return BooleanFragment()
452+
453+
def _normalize(self, name: str, value: object) -> bool:
454+
if not isinstance(value, bool):
455+
raise InvalidConfig(f"{name} must be a boolean, got {value!r}")
456+
return value
457+
458+
def _fill_missing(self) -> bool:
459+
"""Provide a value when not provided by the user."""
460+
return False
461+
462+
419463
def _product(seq: Sequence[int]) -> int:
420464
"""Return the product of the elements in the sequence."""
421465
return functools.reduce(operator.mul, seq, 1)

helion/language/loops.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ..autotuner.config_spec import RangeNumStagesSpec
3434
from ..autotuner.config_spec import RangeUnrollFactorSpec
3535
from ..autotuner.config_spec import RangeWarpSpecializeSpec
36+
from ..autotuner.config_spec import StaticRangeSpec
3637
from . import _decorators
3738
from helion.language.tile_proxy import Tile
3839

@@ -151,6 +152,23 @@ def _check_matching(a: object, b: object) -> None:
151152
)
152153

153154

155+
def _is_constexpr_int(a: object) -> bool:
156+
"""Check if the arg is a plain Python int, i.e. a value that can be emitted as a constexpr loop bound."""
157+
return isinstance(a, int)
158+
# TODO(joydddd): render SymInt backed by Int as constexpr.
159+
# Now the specialized constexpr is assigned to a dynamic variable first
160+
# and then used as a variable. However args to static_range must be constexpr.
161+
# e.g.
162+
# hl.specialize(x.size(0))
163+
# for i in hl.grid(x.size(0))
164+
# ->
165+
# symbol_0 = 64
166+
# for i in tl.static_range(symbol_0):
167+
#
168+
# if isinstance(a, torch.SymInt):
169+
# return isinstance(a._sympy_(), sympy.Integer)
170+
171+
154172
def _normalize_begin_end(
155173
begin_or_end: TypeInfo,
156174
end_or_none: TypeInfo | None,
@@ -225,6 +243,10 @@ def _(
225243
[x.block_id for x in results],
226244
is_tile=True,
227245
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
246+
is_static=all(
247+
_is_constexpr_int(x) or x is None
248+
for x in (*proxy_begin, *proxy_end, *proxy_block_size)
249+
),
228250
)
229251
if unpack:
230252
(result,) = results
@@ -234,7 +256,11 @@ def _(
234256

235257

236258
def _add_config_choices(
237-
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False
259+
block_ids: list[int],
260+
*,
261+
is_tile: bool = False,
262+
has_begin: bool = False,
263+
is_static: bool = False,
238264
) -> None:
239265
config_spec = CompileEnvironment.current().config_spec
240266

@@ -254,6 +280,8 @@ def _add_config_choices(
254280
else:
255281
params = inspect.signature(triton.language.range).parameters
256282
for block_id in block_ids:
283+
if is_static:
284+
config_spec.static_ranges.append(StaticRangeSpec([block_id]))
257285
if "loop_unroll_factor" in params:
258286
config_spec.range_unroll_factors.append(
259287
RangeUnrollFactorSpec([block_id])
@@ -420,6 +448,10 @@ def _(
420448
[x.block_id for x in results],
421449
is_tile=False,
422450
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
451+
is_static=all(
452+
_is_constexpr_int(x) or x is None
453+
for x in (*proxy_begin, *proxy_end, *proxy_step)
454+
),
423455
)
424456
if unpack:
425457
(result,) = results

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
range_num_stages: list[int] | None = None,
3232
range_multi_buffers: list[bool | None] | None = None,
3333
range_flattens: list[bool | None] | None = None,
34+
static_ranges: list[bool] | None = None,
3435
num_warps: int | None = None,
3536
num_stages: int | None = None,
3637
pid_type: PidTypeLiteral | None = None,
@@ -51,6 +52,7 @@ def __init__(
5152
range_num_stages: Number of stages for tl.range calls.
5253
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
5354
range_flattens: Controls flatten parameter for tl.range calls.
55+
static_ranges: Whether to use tl.static_range instead of tl.range.
5456
num_warps: Number of warps per block.
5557
num_stages: Number of stages for software pipelining.
5658
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
@@ -69,6 +71,7 @@ def __init__(
6971
"range_num_stages": range_num_stages,
7072
"range_multi_buffers": range_multi_buffers,
7173
"range_flattens": range_flattens,
74+
"static_ranges": static_ranges,
7275
"num_warps": num_warps,
7376
"num_stages": num_stages,
7477
"indexing": indexing,
@@ -174,6 +177,10 @@ def range_multi_buffers(self) -> list[bool | None]:
174177
def range_flattens(self) -> list[bool | None]:
175178
return cast("list[bool | None]", self.config.get("range_flattens", []))
176179

180+
@property
181+
def static_ranges(self) -> list[bool]:
182+
return cast("list[bool]", self.config.get("static_ranges", []))
183+
177184
@property
178185
def indexing(self) -> IndexingLiteral:
179186
return self.config.get("indexing", "pointer") # type: ignore

0 commit comments

Comments
 (0)