Skip to content

Add static_range #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ parameter for `tl.range()` calls. `True` sets `warp_specialize=True`,
Only available on CUDA devices with Blackwell or newer architectures
when `allow_warp_specialize` setting is enabled.

* **static\_ranges** (`list[bool]`):
Contains one entry per loop dimension with static bounds, controlling whether to use
`tl.static_range()` calls. `True` generates `tl.static_range()` and ignores the `range_*` configs for that loop. `False` (the default when unspecified) generates `tl.range()`.

* **reduction\_loops** (`list[int | None]`):
Contains one entry per reduction dimension (see
`examples/softmax.py`). Using `None` triggers a persistent reduction,
Expand Down
9 changes: 7 additions & 2 deletions helion/_compiler/reduction_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
)
)

range_extra = self.get_tl_range_kwargs(state, self.block_index)
for_node = create(
ast.For,
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
iter=expr_from_string(
f"tl.range(0, ({state.sympy_expr(numel)}), {block_size_var}{range_extra})"
self.get_range_call_str(
state,
[self.block_index],
begin="0",
end=state.sympy_expr(numel),
step=block_size_var,
),
),
body=body,
orelse=[],
Expand Down
48 changes: 40 additions & 8 deletions helion/_compiler/tile_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def mask_var(self, block_idx: int) -> str | None:
def block_size_var(self, block_idx: int) -> str | None:
return self.fn.block_size_var_cache.get((block_idx,))

def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> list[str]:
"""Get the range_extra string for loop unroll factor and num_stages based on config."""
env = CompileEnvironment.current()
kwargs = []
Expand Down Expand Up @@ -159,10 +159,37 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
)
if range_flatten is not None:
kwargs.append(f"flatten={range_flatten}")
return kwargs

if kwargs:
return f", {', '.join(kwargs)}"
return ""
def get_range_call_str(
    self,
    state: CodegenState,
    block_ids: list[int],
    *,
    begin: str | None = None,
    end: str,
    step: str | None = None,
) -> str:
    """Render the loop-iterator call (``tl.range`` or ``tl.static_range``) as text.

    Args:
        state: Codegen state; its ``config.static_ranges`` is consulted.
        block_ids: Block ids covered by the loop being generated.
        begin: Optional source text for the first positional argument.
        end: Source text for the end bound (always emitted).
        step: Optional source text emitted as ``step=<step>``.

    Returns:
        ``tl.static_range(...)`` when the ``static_ranges`` config value is
        ``True`` for every id in ``block_ids``; otherwise ``tl.range(...)``
        with the extra arguments from ``get_tl_range_kwargs`` for the first
        block id appended.
    """
    env = CompileEnvironment.current()

    def _static_requested(block_id: int) -> bool:
        # Only an explicit ``True`` opts in. ``None`` is handed to config_get
        # as its third argument (presumably the fallback -- confirm).
        setting = env.config_spec.static_ranges.config_get(
            state.config.static_ranges, block_id, None
        )
        return setting is True

    every_block_static = all(map(_static_requested, block_ids))

    call_args = [end] if begin is None else [begin, end]
    if step is not None:
        call_args.append(f"step={step}")

    if not every_block_static:
        # The tl.range-only tuning arguments are taken from the first block id.
        call_args = call_args + self.get_tl_range_kwargs(state, block_ids[0])
        return "tl.range(" + ", ".join(call_args) + ")"
    return "tl.static_range(" + ", ".join(call_args) + ")"

def user_size(self, block_index: int) -> sympy.Expr:
raise NotImplementedError
Expand Down Expand Up @@ -407,12 +434,12 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
)
dtype = CompileEnvironment.current().triton_index_type()
lid = self.new_var("lid")
range_extra = self.get_tl_range_kwargs(state, self.block_ids[0])
end_var = f"tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var})"
for_node = create(
ast.For,
target=create(ast.Name, id=lid, ctx=ast.Store()),
iter=expr_from_string(
f"tl.range(tl.cdiv({state.sympy_expr(total_numel)}, {block_size_var}){range_extra})"
self.get_range_call_str(state, self.block_ids, end=end_var)
),
body=(
body := [
Expand Down Expand Up @@ -624,12 +651,17 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
end_expr=self._fold_tile_end_op(state, proxy_end, block_size),
)

range_extra = self.get_tl_range_kwargs(state, block_idx)
for_node = create(
ast.For,
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
iter=expr_from_string(
f"tl.range(begin, end, {block_size_var}{range_extra})",
self.get_range_call_str(
state,
[block_idx],
begin="begin",
end="end",
step=block_size_var,
),
begin=self._to_ast(begin, to_dtype=dtype),
end=self._to_ast(end, to_dtype=dtype),
),
Expand Down
22 changes: 22 additions & 0 deletions helion/autotuner/block_id_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def block_id_lookup(self, block_id: int) -> _BlockIdItemT:
"""Return the index of the block_id in the config."""
return self._data[self._block_id_to_index[block_id]]

def valid_block_ids(self) -> list[int]:
    """Return a new list of every block id currently present in the index map."""
    return [*self._block_id_to_index.keys()]

def disable_block_id(self, block_id: int) -> None:
"""Remove configuration choice for the given block_id."""
self._data = [x for x in self._data if block_id not in x.block_ids]
Expand All @@ -132,6 +136,24 @@ def _flat_config(
"""Map a flattened version of the config using the given function."""
return [spec._flat_config(base, fn) for spec in self._data]

def _reset_config_to_default(
self, name: str, values: object, *, block_ids: list[int] | None = None
) -> list[object]:
"""Set the config values to the default values. If block_ids is provided, only set those values."""
if not values:
return []
assert isinstance(values, list)
assert len(values) == len(self)

if block_ids is None:
block_ids = self.valid_block_ids()
for block_id in block_ids:
if block_id not in self._block_id_to_index:
continue
index = self._block_id_to_index[block_id]
values[index] = self._data[index]._fill_missing()
return values

def _normalize(
self, name: str, values: object, *, flatten: bool = False
) -> list[object]:
Expand Down
44 changes: 44 additions & 0 deletions helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"range_num_stages",
"range_multi_buffers",
"range_flattens",
"static_ranges",
"num_warps",
"num_stages",
"pid_type",
Expand Down Expand Up @@ -85,6 +86,9 @@ class ConfigSpec:
range_flattens: BlockIdSequence[RangeFlattenSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
static_ranges: BlockIdSequence[StaticRangeSpec] = dataclasses.field(
default_factory=BlockIdSequence
)
user_defined_tunables: dict[str, ConfigSpecFragment] = dataclasses.field(
default_factory=dict
)
Expand All @@ -109,6 +113,7 @@ def _remove_duplicates(self) -> None:
self.range_num_stages._remove_duplicates()
self.range_multi_buffers._remove_duplicates()
self.range_flattens._remove_duplicates()
self.static_ranges._remove_duplicates()

def disallow_pid_type(self, pid_type: PidTypeLiteral) -> None:
"""Disallow a pid_type from being used in the config."""
Expand All @@ -135,6 +140,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_num_stage",
"range_multi_buffer",
"range_flatten",
"static_range",
):
if name in config:
names = f"{name}s"
Expand All @@ -153,11 +159,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
("range_num_stages", self.range_num_stages, True),
("range_multi_buffers", self.range_multi_buffers, True),
("range_flattens", self.range_flattens, True),
("static_ranges", self.static_ranges, True),
]:
config[name] = mapping._normalize(
name, config.get(name, ()), flatten=flatten
)

static_range_block_ids = []
for block_id in self.static_ranges.valid_block_ids():
use_static_range = self.static_ranges.config_get(
config.get("static_ranges", ()), # pyre-ignore[6]
block_id,
)
if use_static_range:
static_range_block_ids.append(block_id)

for name, mapping in (
("range_unroll_factors", self.range_unroll_factors),
("range_warp_specializes", self.range_warp_specialize),
("range_num_stages", self.range_num_stages),
("range_multi_buffers", self.range_multi_buffers),
("range_flattens", self.range_flattens),
):
config[name] = mapping._reset_config_to_default(
name, config.get(name, ()), block_ids=static_range_block_ids
)

for name in (
"loop_orders",
"l2_groupings",
Expand All @@ -168,6 +195,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
"range_num_stages",
"range_multi_buffers",
"range_flattens",
"static_ranges",
):
if not config[name]:
config.pop(name)
Expand Down Expand Up @@ -209,6 +237,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"range_num_stages": self.range_num_stages._flat_config(self, fn),
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
"range_flattens": self.range_flattens._flat_config(self, fn),
"static_ranges": self.static_ranges._flat_config(self, fn),
"num_warps": fn(NumWarpsFragment(1, 32, DEFAULT_NUM_WARPS)),
"num_stages": fn(IntegerFragment(1, 8, DEFAULT_NUM_STAGES)),
"indexing": fn(EnumFragment(self._valid_indexing_types())),
Expand All @@ -228,6 +257,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
"range_num_stages",
"range_multi_buffers",
"range_flattens",
"static_ranges",
):
if not config[name]:
config.pop(name)
Expand Down Expand Up @@ -416,6 +446,20 @@ class RangeFlattenSpec(_OptionalBoolSpec):
pass


class StaticRangeSpec(_BlockIdItem):
    """Per-block-id config item whose value is a plain boolean."""

    def _fragment(self, base: ConfigSpec) -> BooleanFragment:
        """Return a fresh ``BooleanFragment``; ``base`` is not consulted."""
        return BooleanFragment()

    def _normalize(self, name: str, value: object) -> bool:
        """Return ``value`` unchanged when it is a ``bool``.

        Raises:
            InvalidConfig: if ``value`` is anything other than a ``bool``.
        """
        if isinstance(value, bool):
            return value
        raise InvalidConfig(f"{name} must be a boolean, got {value!r}")

    def _fill_missing(self) -> bool:
        """Return ``False``, the value used when the user supplied none."""
        return False


def _product(seq: Sequence[int]) -> int:
"""Return the product of the elements in the sequence."""
return functools.reduce(operator.mul, seq, 1)
34 changes: 33 additions & 1 deletion helion/language/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..autotuner.config_spec import RangeNumStagesSpec
from ..autotuner.config_spec import RangeUnrollFactorSpec
from ..autotuner.config_spec import RangeWarpSpecializeSpec
from ..autotuner.config_spec import StaticRangeSpec
from . import _decorators
from helion.language.tile_proxy import Tile

Expand Down Expand Up @@ -151,6 +152,23 @@ def _check_matching(a: object, b: object) -> None:
)


def _is_constexpr_int(a: object) -> bool:
"""Check if the arg is specialized."""
return isinstance(a, int)
# TODO(joydddd): render SymInt backed by Int as constexpr.
# Now the specialized constexpr is assigned to a dynamic variable first
# and then used as a variable. However args to static_range must be constexpr.
# e.g.
# hl.specialize(x.size(0))
# for i in hl.grid(x.size(0))
# ->
# symbol_0 = 64
# for i in tl.static_range(symbol_0):
#
# if isinstance(a, torch.SymInt):
# return isinstance(a._sympy_(), sympy.Integer)


def _normalize_begin_end(
begin_or_end: TypeInfo,
end_or_none: TypeInfo | None,
Expand Down Expand Up @@ -225,6 +243,10 @@ def _(
[x.block_id for x in results],
is_tile=True,
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
is_static=all(
_is_constexpr_int(x) or x is None
for x in (*proxy_begin, *proxy_end, *proxy_block_size)
),
)
if unpack:
(result,) = results
Expand All @@ -234,7 +256,11 @@ def _(


def _add_config_choices(
block_ids: list[int], *, is_tile: bool = False, has_begin: bool = False
block_ids: list[int],
*,
is_tile: bool = False,
has_begin: bool = False,
is_static: bool = False,
) -> None:
config_spec = CompileEnvironment.current().config_spec

Expand All @@ -254,6 +280,8 @@ def _add_config_choices(
else:
params = inspect.signature(triton.language.range).parameters
for block_id in block_ids:
if is_static:
config_spec.static_ranges.append(StaticRangeSpec([block_id]))
if "loop_unroll_factor" in params:
config_spec.range_unroll_factors.append(
RangeUnrollFactorSpec([block_id])
Expand Down Expand Up @@ -420,6 +448,10 @@ def _(
[x.block_id for x in results],
is_tile=False,
has_begin=not all((isinstance(x, int) and x == 0) for x in proxy_begin),
is_static=all(
_is_constexpr_int(x) or x is None
for x in (*proxy_begin, *proxy_end, *proxy_step)
),
)
if unpack:
(result,) = results
Expand Down
7 changes: 7 additions & 0 deletions helion/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
range_num_stages: list[int] | None = None,
range_multi_buffers: list[bool | None] | None = None,
range_flattens: list[bool | None] | None = None,
static_ranges: list[bool] | None = None,
num_warps: int | None = None,
num_stages: int | None = None,
pid_type: PidTypeLiteral | None = None,
Expand All @@ -51,6 +52,7 @@ def __init__(
range_num_stages: Number of stages for tl.range calls.
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
range_flattens: Controls flatten parameter for tl.range calls.
static_ranges: Whether to use tl.static_range instead of tl.range.
num_warps: Number of warps per block.
num_stages: Number of stages for software pipelining.
pid_type: Program ID type strategy ("flat", "xyz", "persistent_blocked", "persistent_interleaved").
Expand All @@ -69,6 +71,7 @@ def __init__(
"range_num_stages": range_num_stages,
"range_multi_buffers": range_multi_buffers,
"range_flattens": range_flattens,
"static_ranges": static_ranges,
"num_warps": num_warps,
"num_stages": num_stages,
"indexing": indexing,
Expand Down Expand Up @@ -174,6 +177,10 @@ def range_multi_buffers(self) -> list[bool | None]:
def range_flattens(self) -> list[bool | None]:
return cast("list[bool | None]", self.config.get("range_flattens", []))

@property
def static_ranges(self) -> list[bool]:
    """The ``"static_ranges"`` config entry, or an empty list when it is absent."""
    stored = self.config.get("static_ranges", [])
    return cast("list[bool]", stored)

@property
def indexing(self) -> IndexingLiteral:
    """The ``"indexing"`` config entry, falling back to ``"pointer"`` when unset."""
    settings = self.config
    return settings.get("indexing", "pointer")  # type: ignore
Expand Down
Loading
Loading