Skip to content

Commit 52306b0

Browse files
authored
Add tl.range warp_specialize to autotuner (#230)
1 parent 43faf72 commit 52306b0

File tree

9 files changed

+157
-84
lines changed

9 files changed

+157
-84
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ Contains one entry per loop dimension, controlling the `flatten`
209209
parameter for `tl.range()` calls. `True` sets `flatten=True`,
210210
`False` sets `flatten=False`, and `None` omits the parameter.
211211

212+
* **range\_warp\_specializes** (`list[bool | None]`):
213+
Contains one entry per loop dimension, controlling the `warp_specialize`
214+
parameter for `tl.range()` calls. `True` sets `warp_specialize=True`,
215+
`False` sets `warp_specialize=False`, and `None` omits the parameter.
216+
Only available on CUDA devices with Blackwell or newer architectures
217+
when the `allow_warp_specialize` setting is enabled.
218+
212219
* **reduction\_loops** (`list[int | None]`):
213220
Contains one entry per reduction dimension (see
214221
`examples/softmax.py`). Using `None` triggers a persistent reduction,

helion/_compiler/device_function.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,16 @@ def codegen_function_def(self) -> ast.FunctionDef:
359359

360360
def codegen_function_call(self) -> ast.AST:
361361
args = [arg.host_str() for arg in self.sorted_args()]
362+
363+
# Workaround for triton bug: warp_specialize requires at least 4 warps
364+
# See: https://github.com/triton-lang/triton/issues/7354
365+
num_warps = self.config.num_warps
366+
if any(self.config.range_warp_specializes):
367+
num_warps = max(4, num_warps)
368+
362369
args.extend(
363370
[
364-
f"num_warps={self.config.num_warps}",
371+
f"num_warps={num_warps}",
365372
f"num_stages={self.config.num_stages}",
366373
]
367374
)

helion/_compiler/tile_strategy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def get_tl_range_kwargs(self, state: CodegenState, block_idx: int) -> str:
134134
if range_unroll_factor > 0:
135135
kwargs.append(f"loop_unroll_factor={range_unroll_factor}")
136136

137+
range_warp_specialize = env.config_spec.range_warp_specialize.config_get(
138+
state.config.range_warp_specializes, block_idx, None
139+
)
140+
if range_warp_specialize is not None:
141+
kwargs.append(f"warp_specialize={range_warp_specialize}")
142+
137143
range_num_stages = env.config_spec.range_num_stages.config_get(
138144
state.config.range_num_stages, block_idx, 0
139145
)

helion/autotuner/config_spec.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"reduction_loops",
3838
"flatten_loops",
3939
"range_unroll_factors",
40+
"range_warp_specializes",
4041
"range_num_stages",
4142
"range_multi_buffers",
4243
"range_flattens",
@@ -68,6 +69,9 @@ class ConfigSpec:
6869
range_unroll_factors: BlockIdSequence[RangeUnrollFactorSpec] = dataclasses.field(
6970
default_factory=BlockIdSequence
7071
)
72+
range_warp_specialize: BlockIdSequence[RangeWarpSpecializeSpec] = dataclasses.field(
73+
default_factory=BlockIdSequence
74+
)
7175
range_num_stages: BlockIdSequence[RangeNumStagesSpec] = dataclasses.field(
7276
default_factory=BlockIdSequence
7377
)
@@ -87,6 +91,7 @@ def _remove_duplicates(self) -> None:
8791
self.l2_groupings._remove_duplicates()
8892
self.flatten_loops._remove_duplicates()
8993
self.range_unroll_factors._remove_duplicates()
94+
self.range_warp_specialize._remove_duplicates()
9095
self.range_num_stages._remove_duplicates()
9196
self.range_multi_buffers._remove_duplicates()
9297
self.range_flattens._remove_duplicates()
@@ -104,6 +109,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
104109
"l2_grouping",
105110
"flatten_loop",
106111
"range_unroll_factor",
112+
"range_warp_specialize",
107113
"range_num_stage",
108114
"range_multi_buffer",
109115
"range_flatten",
@@ -121,6 +127,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
121127
("loop_orders", self.loop_orders, False),
122128
("reduction_loops", self.reduction_loops, True),
123129
("range_unroll_factors", self.range_unroll_factors, True),
130+
("range_warp_specializes", self.range_warp_specialize, True),
124131
("range_num_stages", self.range_num_stages, True),
125132
("range_multi_buffers", self.range_multi_buffers, True),
126133
("range_flattens", self.range_flattens, True),
@@ -135,6 +142,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
135142
"flatten_loops",
136143
"reduction_loops",
137144
"range_unroll_factors",
145+
"range_warp_specializes",
138146
"range_num_stages",
139147
"range_multi_buffers",
140148
"range_flattens",
@@ -168,6 +176,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
168176
"l2_groupings": self.l2_groupings._flat_config(self, fn),
169177
"reduction_loops": self.reduction_loops._flat_config(self, fn),
170178
"range_unroll_factors": self.range_unroll_factors._flat_config(self, fn),
179+
"range_warp_specializes": self.range_warp_specialize._flat_config(self, fn),
171180
"range_num_stages": self.range_num_stages._flat_config(self, fn),
172181
"range_multi_buffers": self.range_multi_buffers._flat_config(self, fn),
173182
"range_flattens": self.range_flattens._flat_config(self, fn),
@@ -198,6 +207,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
198207
"reduction_loops",
199208
"l2_groupings",
200209
"range_unroll_factors",
210+
"range_warp_specializes",
201211
"range_num_stages",
202212
"range_multi_buffers",
203213
"range_flattens",
@@ -342,24 +352,7 @@ def _fill_missing(self) -> None:
342352
return None
343353

344354

345-
class RangeUnrollFactorSpec(_BlockIdItem):
346-
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
347-
return IntegerFragment(0, 4, 0)
348-
349-
def _normalize(self, name: str, value: object) -> int:
350-
if not isinstance(value, int):
351-
raise InvalidConfig(f"{name} must be an integer, got {value!r}")
352-
return value
353-
354-
def _fill_missing(self) -> int:
355-
"""Provide a value when not provided by the user."""
356-
return 0
357-
358-
359-
class RangeNumStagesSpec(_BlockIdItem):
360-
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
361-
return IntegerFragment(0, 4, 0)
362-
355+
class _OptionalIntSpec(_BlockIdItem):
363356
def _normalize(self, name: str, value: object) -> int:
364357
if not isinstance(value, int):
365358
raise InvalidConfig(f"{name} must be an integer, got {value!r}")
@@ -370,7 +363,7 @@ def _fill_missing(self) -> int:
370363
return 0
371364

372365

373-
class RangeMultiBufferSpec(_BlockIdItem):
366+
class _OptionalBoolSpec(_BlockIdItem):
374367
def _fragment(self, base: ConfigSpec) -> EnumFragment:
375368
return EnumFragment((None, False, True))
376369

@@ -384,18 +377,26 @@ def _fill_missing(self) -> None:
384377
return None
385378

386379

387-
class RangeFlattenSpec(_BlockIdItem):
388-
def _fragment(self, base: ConfigSpec) -> EnumFragment:
389-
return EnumFragment((None, False, True))
380+
class RangeUnrollFactorSpec(_OptionalIntSpec):
381+
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
382+
return IntegerFragment(0, 4, 0)
390383

391-
def _normalize(self, name: str, value: object) -> bool | None:
392-
if value is not None and not isinstance(value, bool):
393-
raise InvalidConfig(f"{name} must be a boolean or None, got {value!r}")
394-
return value
395384

396-
def _fill_missing(self) -> None:
397-
"""Provide a value when not provided by the user."""
398-
return
385+
class RangeWarpSpecializeSpec(_OptionalBoolSpec):
386+
pass
387+
388+
389+
class RangeNumStagesSpec(_OptionalIntSpec):
390+
def _fragment(self, base: ConfigSpec) -> IntegerFragment:
391+
return IntegerFragment(0, 4, 0)
392+
393+
394+
class RangeMultiBufferSpec(_OptionalBoolSpec):
395+
pass
396+
397+
398+
class RangeFlattenSpec(_OptionalBoolSpec):
399+
pass
399400

400401

401402
def _product(seq: Sequence[int]) -> int:

helion/language/loops.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ast
44
import builtins
5+
import inspect
56
from typing import TYPE_CHECKING
67
from typing import Iterator
78
from typing import Sequence
@@ -10,6 +11,7 @@
1011

1112
import torch
1213
from torch._inductor.runtime.triton_heuristics import get_max_y_grid
14+
import triton.language
1315

1416
from .. import exc
1517
from .._compiler.ast_extension import ExtendedAST
@@ -30,6 +32,7 @@
3032
from ..autotuner.config_spec import RangeMultiBufferSpec
3133
from ..autotuner.config_spec import RangeNumStagesSpec
3234
from ..autotuner.config_spec import RangeUnrollFactorSpec
35+
from ..autotuner.config_spec import RangeWarpSpecializeSpec
3336
from . import _decorators
3437
from helion.language.tile_proxy import Tile
3538

@@ -248,11 +251,30 @@ def _add_config_choices(
248251
config_spec.l2_groupings.append(L2GroupingSpec(block_ids))
249252
config_spec.allow_use_yz_grid = _allow_use_yz_grid(config_spec, block_ids)
250253
else:
254+
params = inspect.signature(triton.language.range).parameters
251255
for block_id in block_ids:
252-
config_spec.range_unroll_factors.append(RangeUnrollFactorSpec([block_id]))
253-
config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
254-
config_spec.range_multi_buffers.append(RangeMultiBufferSpec([block_id]))
255-
config_spec.range_flattens.append(RangeFlattenSpec([block_id]))
256+
if "loop_unroll_factor" in params:
257+
config_spec.range_unroll_factors.append(
258+
RangeUnrollFactorSpec([block_id])
259+
)
260+
if _supports_warp_specialize() and "warp_specialize" in params:
261+
config_spec.range_warp_specialize.append(
262+
RangeWarpSpecializeSpec([block_id])
263+
)
264+
if "num_stages" in params:
265+
config_spec.range_num_stages.append(RangeNumStagesSpec([block_id]))
266+
if "disallow_acc_multi_buffer" in params:
267+
config_spec.range_multi_buffers.append(RangeMultiBufferSpec([block_id]))
268+
if "flatten" in params:
269+
config_spec.range_flattens.append(RangeFlattenSpec([block_id]))
270+
271+
272+
def _supports_warp_specialize() -> bool:
273+
"""Check if the current device supports warp specialization."""
274+
env = CompileEnvironment.current()
275+
if env.device.type != "cuda" or not env.settings.allow_warp_specialize:
276+
return False
277+
return torch.cuda.get_device_capability() >= (12, 0)
256278

257279

258280
def _allow_use_yz_grid(config_spec: ConfigSpec, block_ids: list[int]) -> bool:

helion/runtime/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
l2_groupings: list[int] | None = None,
2727
reduction_loops: list[int | None] | None = None,
2828
range_unroll_factors: list[int] | None = None,
29+
range_warp_specializes: list[bool | None] | None = None,
2930
range_num_stages: list[int] | None = None,
3031
range_multi_buffers: list[bool | None] | None = None,
3132
range_flattens: list[bool | None] | None = None,
@@ -45,6 +46,7 @@ def __init__(
4546
l2_groupings: Reorders program IDs for L2 cache locality.
4647
reduction_loops: Configures reduction loop behavior.
4748
range_unroll_factors: Loop unroll factors for tl.range calls.
49+
range_warp_specializes: Warp specialization for tl.range calls.
4850
range_num_stages: Number of stages for tl.range calls.
4951
range_multi_buffers: Controls disallow_acc_multi_buffer for tl.range calls.
5052
range_flattens: Controls flatten parameter for tl.range calls.
@@ -62,6 +64,7 @@ def __init__(
6264
"l2_groupings": l2_groupings,
6365
"reduction_loops": reduction_loops,
6466
"range_unroll_factors": range_unroll_factors,
67+
"range_warp_specializes": range_warp_specializes,
6568
"range_num_stages": range_num_stages,
6669
"range_multi_buffers": range_multi_buffers,
6770
"range_flattens": range_flattens,
@@ -154,6 +157,10 @@ def use_yz_grid(self) -> bool:
154157
def range_unroll_factors(self) -> list[int]:
155158
return cast("list[int]", self.config.get("range_unroll_factors", []))
156159

160+
@property
161+
def range_warp_specializes(self) -> list[bool | None]:
162+
return cast("list[bool | None]", self.config.get("range_warp_specializes", []))
163+
157164
@property
158165
def range_num_stages(self) -> list[int]:
159166
return cast("list[int]", self.config.get("range_num_stages", []))

helion/runtime/settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ class _Settings:
6666
autotune_precompile: bool = sys.platform != "win32"
6767
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
6868
force_autotune: bool = os.environ.get("HELION_FORCE_AUTOTUNE", "0") == "1"
69+
allow_warp_specialize: bool = (
70+
os.environ.get("HELION_ALLOW_WARP_SPECIALIZE", "1") == "1"
71+
)
6972

7073

7174
class Settings(_Settings):
@@ -85,6 +88,7 @@ class Settings(_Settings):
8588
"autotune_precompile": "If True, precompile the kernel before autotuning. Requires fork-safe environment.",
8689
"print_output_code": "If True, print the output code of the kernel to stderr.",
8790
"force_autotune": "If True, force autotuning even if a config is provided.",
91+
"allow_warp_specialize": "If True, allow warp specialization for tl.range calls on CUDA devices.",
8892
}
8993
assert __slots__.keys() == {field.name for field in dataclasses.fields(_Settings)}
9094

test/test_autotuner.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from helion.autotuner.config_generation import ConfigGeneration
2020
from helion.autotuner.random_search import RandomSearch
2121
import helion.language as hl
22+
from helion.language import loops
2223

2324
datadir = Path(__file__).parent / "data"
2425
basic_kernels = import_path(datadir / "basic_kernels.py")
@@ -34,6 +35,7 @@ def setUp(self):
3435
random.seed(112)
3536

3637
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)
38+
@patch.object(loops, "_supports_warp_specialize", lambda: True)
3739
def test_config_fragment0(self):
3840
args = (
3941
torch.randn([512, 512], device=DEVICE),
@@ -44,16 +46,16 @@ def test_config_fragment0(self):
4446
self.assertExpectedInline(
4547
"\n".join(map(repr, configs)),
4648
"""\
47-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
48-
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=8, num_stages=1, indexing='block_ptr')
49-
helion.Config(block_sizes=[16, 128, 32], loop_orders=[[1, 0]], l2_groupings=[4], range_unroll_factors=[1], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[True], num_warps=32, num_stages=3, indexing='tensor_descriptor')
50-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='tensor_descriptor')
51-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[4], range_num_stages=[3], range_multi_buffers=[None], range_flattens=[False], num_warps=32, num_stages=7, indexing='tensor_descriptor')
52-
helion.Config(block_sizes=[16, 64, 32], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[True], num_warps=32, num_stages=3, indexing='block_ptr')
53-
helion.Config(block_sizes=[16, 32, 64], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[4], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=16, num_stages=6, indexing='pointer')
54-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=1, num_stages=6, indexing='block_ptr')
55-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[64], range_unroll_factors=[0], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[None], num_warps=32, num_stages=7, indexing='block_ptr')
56-
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[4], range_unroll_factors=[3], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[None], num_warps=8, num_stages=6, indexing='block_ptr')""",
49+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[0], range_multi_buffers=[None], range_flattens=[None], num_warps=4, num_stages=3, indexing='pointer')
50+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[8], range_unroll_factors=[1], range_warp_specializes=[False], range_num_stages=[1], range_multi_buffers=[True], range_flattens=[False], num_warps=1, num_stages=7, indexing='tensor_descriptor')
51+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[2], range_unroll_factors=[1], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[True], range_flattens=[True], num_warps=2, num_stages=8, indexing='tensor_descriptor')
52+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[1, 0]], l2_groupings=[1], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[1], range_multi_buffers=[False], range_flattens=[False], num_warps=32, num_stages=2, indexing='tensor_descriptor')
53+
helion.Config(block_sizes=[16, 32, 64], loop_orders=[[1, 0]], l2_groupings=[64], range_unroll_factors=[2], range_warp_specializes=[True], range_num_stages=[3], range_multi_buffers=[True], range_flattens=[None], num_warps=4, num_stages=7, indexing='pointer')
54+
helion.Config(block_sizes=[256, 128, 16], loop_orders=[[0, 1]], l2_groupings=[2], range_unroll_factors=[4], range_warp_specializes=[True], range_num_stages=[4], range_multi_buffers=[None], range_flattens=[False], num_warps=8, num_stages=4, indexing='tensor_descriptor')
55+
helion.Config(block_sizes=[16, 32, 16], loop_orders=[[1, 0]], l2_groupings=[16], range_unroll_factors=[0], range_warp_specializes=[True], range_num_stages=[2], range_multi_buffers=[None], range_flattens=[False], num_warps=1, num_stages=8, indexing='tensor_descriptor')
56+
helion.Config(block_sizes=[16, 16, 16], loop_orders=[[0, 1]], l2_groupings=[32], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[False], num_warps=4, num_stages=4, indexing='tensor_descriptor')
57+
helion.Config(block_sizes=[16, 16, 64], loop_orders=[[0, 1]], l2_groupings=[8], range_unroll_factors=[0], range_warp_specializes=[None], range_num_stages=[2], range_multi_buffers=[False], range_flattens=[True], num_warps=16, num_stages=4, indexing='block_ptr')
58+
helion.Config(block_sizes=[32, 16, 16], loop_orders=[[0, 1]], l2_groupings=[16], range_unroll_factors=[2], range_warp_specializes=[False], range_num_stages=[0], range_multi_buffers=[True], range_flattens=[False], num_warps=4, num_stages=1, indexing='tensor_descriptor')""",
5759
)
5860

5961
@patch.object(_compat, "_supports_tensor_descriptor", lambda: True)

0 commit comments

Comments
 (0)