Skip to content

Commit 485b219

Browse files
committed
Add hl.signal
stack-info: PR: #233, branch: joydddd/stack/8
1 parent bd0f27a commit 485b219

File tree

6 files changed

+362
-3
lines changed

6 files changed

+362
-3
lines changed

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .scan_ops import associative_scan as associative_scan
1616
from .scan_ops import cumprod as cumprod
1717
from .scan_ops import cumsum as cumsum
18+
from .signal_wait import signal as signal
1819
from .signal_wait import wait as wait
1920
from .tile_ops import tile_begin as tile_begin
2021
from .tile_ops import tile_block_size as tile_block_size

helion/language/signal_wait.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66
from torch.fx import has_side_effect
77

88
from .. import exc
9+
from .._compiler.indexing_strategy import SubscriptIndexing
910
from . import _decorators
1011

1112
if TYPE_CHECKING:
1213
import ast
1314

1415
from .._compiler.inductor_lowering import CodegenState
1516

17+
__all__ = ["signal", "wait"]
18+
1619

1720
@has_side_effect
1821
@_decorators.api(tiles_as_sizes=True)
@@ -146,3 +149,143 @@ def _(state: CodegenState) -> ast.AST:
146149
signal=signal_expr,
147150
update=update_expr,
148151
)
152+
153+
154+
@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def signal(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    wait_for: int | None = None,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> torch.Tensor:
    """Set the selected slice of ``signal_pad`` to the ``signal`` value.

    Args:
        signal_pad: The signal pad tensor to signal.
        index: Indices used to index into ``signal_pad``.
        signal: The value to send (default: 1).
        wait_for: The value to wait for before sending the signal. Only
            valid for op = 'atomic_cas'.
        op: The memory op used to send the signal (default: 'atomic_xchg').
        sem: The memory semantic of the memory op (default: 'release').
        scope: The scope of the memory op (default: 'gpu').
        skip_sync: Skip the syncthreads before sending the signal
            (default: False).

    Raises:
        exc.NotInsideKernel: Whenever this Python body itself is executed;
            the body does nothing except raise.
    """
    raise exc.NotInsideKernel
178+
179+
180+
@_decorators.prepare_args(signal)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    wait_for: int | None = None,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
    """Validate the arguments of ``signal`` and normalize ``index``.

    Checks, in order: ``op`` is a known op, ``wait_for`` is given exactly
    when ``op`` is 'atomic_cas', ``sem`` is a known memory semantic and
    ``scope`` is a known scope. The index is then passed through
    ``Tile._prepare_index`` and ``Tile._tiles_to_sizes``.

    Returns:
        The arguments as a tuple in signature order, with ``index``
        replaced by its normalized form.

    Raises:
        ValueError: If any of the checks above fails.
    """
    from helion.language.tile_proxy import Tile

    valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
    valid_sems = {"relaxed", "release", "acq_rel"}
    valid_scopes = {"sys", "gpu"}

    if op not in valid_ops:
        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}. ")

    # Past this point `op` is one of the three valid ops, so "not atomic_cas"
    # means atomic_add / atomic_xchg.
    has_wait_value = wait_for is not None
    if op == "atomic_cas":
        if not has_wait_value:
            raise ValueError(
                f"{op} without a wait_for value. Do you want to use 'atomic_add' or 'atomic_xchg' instead? "
            )
    elif has_wait_value:
        raise ValueError(
            f"{op} with a wait_for value. Do you want to use 'atomic_cas' instead? "
        )

    if sem not in valid_sems:
        raise ValueError(
            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
        )

    if scope not in valid_scopes:
        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

    normalized_index = Tile._tiles_to_sizes(Tile._prepare_index(index))

    return (
        signal_pad,
        normalized_index,
        signal,
        wait_for,
        op,
        sem,
        scope,
        skip_sync,
    )
221+
222+
223+
@_decorators.register_fake(signal)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    wait_for: int | None = None,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> torch.Tensor:
    """Fake implementation of ``signal``.

    Returns:
        A new empty tensor created from ``signal_pad`` whose shape is
        ``SubscriptIndexing.compute_shape(signal_pad, index)``. The other
        arguments are not used here.
    """
    result_shape = SubscriptIndexing.compute_shape(signal_pad, index)
    return signal_pad.new_empty(result_shape)
235+
236+
237+
@_decorators.codegen(signal)
def _(state: CodegenState) -> ast.AST:
    """Generate the device-side call expression for ``signal``.

    Reads the eight proxy arguments of ``signal`` from ``state`` and builds
    one of three calls:

    * ``helion.runtime.triton_send_signal`` for any op other than
      'atomic_cas';
    * ``helion.runtime.triton_wait_signal`` for 'atomic_cas' when the
      indexed signal pad has an empty (scalar) shape;
    * ``helion.runtime.triton_wait_multiple_signal`` for 'atomic_cas'
      otherwise.

    For the 'atomic_cas' calls the emitted code always passes
    ``skip_sync=True`` and forwards the user's flag inverted as
    ``sync_before=(not skip_sync)``.
    """
    import ast

    from .._compiler.ast_extension import expr_from_string
    from .._compiler.indexing_strategy import SubscriptIndexing

    signal_pad, index, signal, wait_for, op, sem, scope, skip_sync = (
        state.proxy_arg(position) for position in range(8)
    )

    assert isinstance(signal_pad, torch.Tensor)
    assert isinstance(index, list)

    indices = SubscriptIndexing.create(state, signal_pad, index)
    signal_pad_name = state.device_function.tensor_arg(signal_pad).name

    signal_expr = ast.Constant(value=signal)
    # A missing wait_for is emitted as the constant 0.
    wait_for_expr = ast.Constant(value=wait_for if wait_for is not None else 0)
    skip_sync_expr = ast.Constant(value=skip_sync)
    assert type(op) is str
    assert type(sem) is str
    assert type(scope) is str

    if op != "atomic_cas":
        # Plain send: the signal is written without waiting on a value.
        send_call = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
        return expr_from_string(
            send_call,
            offset=indices.index_expr,
            signal=signal_expr,
            skip_sync=skip_sync_expr,
        )

    # Compare-and-swap: pick the scalar or multi-barrier helper from the
    # shape of the indexed signal pad.
    bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
    if len(bar_tensor_shape) == 0:
        wait_call = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
    else:
        wait_call = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync), sync_after=True)"

    return expr_from_string(
        wait_call,
        offset=indices.index_expr,
        wait_for=wait_for_expr,
        signal=signal_expr,
        skip_sync=skip_sync_expr,
    )

helion/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .config import Config as Config
99
from .kernel import Kernel as Kernel
1010
from .kernel import kernel as kernel
11+
from .triton_helpers import triton_send_signal as triton_send_signal
1112
from .triton_helpers import triton_wait_signal as triton_wait_signal
1213

1314

helion/runtime/triton_helpers.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,56 @@
33
import triton
44
import triton.language as tl
55

6-
__all__ = ["triton_wait_signal"]
6+
__all__ = ["triton_send_signal", "triton_wait_multiple_signal", "triton_wait_signal"]
7+
8+
9+
@triton.jit
def triton_send_signal(
    addr: tl.tensor,
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
) -> tl.tensor:
    """
    Signal global memory barrier(s).

    This function atomically sets global memory barriers to an update value,
    signaling to other CTAs waiting on the barrier(s).

    Args:
        addr: Memory address of the barrier(s) to signal
        update: The value the atomic op applies to the barrier(s)
            (stored for "atomic_xchg", added for "atomic_add")
        sem: Memory semantics for the atomic operation. Options: "release", "relaxed".
        scope: Scope of the atomic operation. Options: "gpu", "sys"
        op: Atomic operation type: "atomic_xchg", "atomic_add"
        skip_sync: Skip CTA synchronization before setting the barrier.
    Returns:
        The old value of the barrier(s) before the update.
    """
    # Validate every compile-time argument up front, before any code is
    # emitted, using tl.static_assert like triton_wait_signal does.
    tl.static_assert(
        sem == "release" or sem == "relaxed",
        "Invalid memory semantic. options: 'release', 'relaxed'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu','sys'. "
    )
    tl.static_assert(
        op == "atomic_xchg" or op == "atomic_add",
        "Invalid op. options: 'atomic_xchg', 'atomic_add'. ",
    )

    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )

    if op == "atomic_xchg":
        barrier_status = tl.atomic_xchg(addr, update, sem=sem, scope=scope)
    else:
        # Only "atomic_add" can reach here; the static_assert above rejects
        # every other op.
        barrier_status = tl.atomic_add(addr, update, sem=sem, scope=scope)
    return barrier_status
756

857

958
@triton.jit
@@ -15,6 +64,7 @@ def triton_wait_signal(
1564
scope: tl.constexpr,
1665
op: tl.constexpr,
1766
skip_sync: tl.constexpr,
67+
sync_before: tl.constexpr = False, # pyre-ignore[9]
1868
) -> None:
1969
"""
2070
Wait for a global memory barrier to reach the expected value.
@@ -30,15 +80,16 @@ def triton_wait_signal(
3080
scope: Scope of the atomic operation. Options: "gpu", "sys"
3181
op: Atomic operation type: "ld", "atomic_cas"
3282
skip_sync: Skip CTA sync after acquiring the barrier (default: False)
83+
sync_before: Add a CTA sync before the wait (default: False)
3384
"""
3485
tl.static_assert(
3586
addr.type.is_ptr(),
3687
"Barrier address must be a scalar. Do you want to use '_triton_wait_multiple_signal'? ",
3788
)
3889

3990
tl.static_assert(
40-
sem == "acquire" or sem == "relaxed",
41-
"Invalid memory semantic. options: 'acquire', 'relaxed'. ",
91+
(sem == "acquire" or sem == "relaxed") or sem == "release",
92+
"Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
4293
)
4394
tl.static_assert(
4495
scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
@@ -48,6 +99,11 @@ def triton_wait_signal(
4899
"Invalid op. options: 'ld', 'atomic_cas'. ",
49100
)
50101

102+
if sync_before:
103+
tl.inline_asm_elementwise(
104+
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
105+
)
106+
51107
# Spin-wait loop:
52108
# Uses atomic_add with update=0 for ld.global.{sem}.{scope}
53109
# Triton generates smem broadcasting of tl.atomic_add return value in ptx,
@@ -71,3 +127,18 @@ def triton_wait_signal(
71127
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
72128
)
73129
# tl.debug_barrier() cause significant performance loss. (Perhaps breaks triton prefetching?)
130+
131+
132+
@triton.jit
def triton_wait_multiple_signal(
    addr: tl.tensor,
    expect: tl.constexpr,  # wait until lock is set to expect
    update: tl.constexpr,  # update the lock once it is acquired.
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
    sync_before: tl.constexpr = False,  # pyre-ignore[9]
    sync_after: tl.constexpr = False,  # pyre-ignore[9]
) -> None:
    """
    Placeholder for waiting on multiple global memory barriers.

    Not implemented yet: the body only raises NotImplementedError.

    Args:
        addr: Memory address of the barriers to wait on
        expect: Value each barrier must hold before the wait completes
        update: Value to set once the barrier is acquired
        sem: Memory semantics for the atomic operation
        scope: Scope of the atomic operation
        op: Atomic operation type
        skip_sync: Skip CTA sync after acquiring the barriers
        sync_before: Add a CTA sync before the wait (default: False)
        sync_after: Accepted because the `hl.signal` codegen passes
            `sync_after=True` for the multi-barrier 'atomic_cas' case;
            currently unused. NOTE(review): confirm the intended meaning
            once this is implemented.
    """
    # TODO(joydddd): waiting on multiple barriers at the same time where each thread waits on a different barrier
    raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")

test/test_signal_wait.expected

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,81 @@
11
This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
22
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
33

4+
--- assertExpectedJournal(TestWait.test_signal_basic)
5+
from __future__ import annotations
6+
7+
import torch
8+
import helion
9+
import triton
10+
import triton.language as tl
11+
12+
@triton.jit
13+
def _gmem_signal_scalar_bar_kernel_kernel(signal_pad, signal_pad_stride_0):
14+
pid_0 = tl.program_id(0)
15+
offset_0 = pid_0
16+
helion.runtime.triton_send_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
17+
18+
def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor):
19+
n, = signal_pad.shape
20+
_gmem_signal_scalar_bar_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
21+
return signal_pad
22+
23+
def _gmem_signal_scalar_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
24+
n, = signal_pad.shape
25+
from helion.runtime.precompile_shim import make_precompiler
26+
return make_precompiler(_gmem_signal_scalar_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
27+
28+
--- assertExpectedJournal(TestWait.test_signal_cas)
29+
from __future__ import annotations
30+
31+
import torch
32+
import helion
33+
import triton
34+
import triton.language as tl
35+
36+
@triton.jit
37+
def _gmem_signal_cas_kernel_kernel(signal_pad, signal_pad_stride_0):
38+
pid_0 = tl.program_id(0)
39+
offset_0 = pid_0
40+
helion.runtime.triton_wait_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, expect=0, update=1, sem='release', scope='gpu', op='atomic_cas', skip_sync=True, sync_before=not False)
41+
42+
def gmem_signal_cas_kernel(signal_pad: torch.Tensor):
43+
n, = signal_pad.shape
44+
_gmem_signal_cas_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
45+
return signal_pad
46+
47+
def _gmem_signal_cas_kernel_make_precompiler(signal_pad: torch.Tensor):
48+
n, = signal_pad.shape
49+
from helion.runtime.precompile_shim import make_precompiler
50+
return make_precompiler(_gmem_signal_cas_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
51+
52+
--- assertExpectedJournal(TestWait.test_signal_multiple)
53+
from __future__ import annotations
54+
55+
import torch
56+
import helion
57+
import triton
58+
import triton.language as tl
59+
60+
@triton.jit
61+
def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
62+
pid_0 = tl.program_id(0)
63+
offset_0 = pid_0 * _BLOCK_SIZE_0
64+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
65+
helion.runtime.triton_send_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
66+
67+
def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor):
68+
n, = signal_pad.shape
69+
_BLOCK_SIZE_0 = 4
70+
_gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
71+
return signal_pad
72+
73+
def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
74+
n, = signal_pad.shape
75+
_BLOCK_SIZE_0 = 4
76+
from helion.runtime.precompile_shim import make_precompiler
77+
return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
78+
479
--- assertExpectedJournal(TestWait.test_wait_2d_tile)
580
from __future__ import annotations
681

0 commit comments

Comments
 (0)