
Commit a3482a3

Add hl.signal
stack-info: PR: #233, branch: joydddd/stack/8
1 parent 3044010 commit a3482a3
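
For context, hl.signal is the send-side counterpart of the existing hl.wait: each CTA atomically writes a value into one slice of a global-memory signal pad. A minimal usage sketch, adapted from the tests added in this commit (the kernel name, tensor size, and the `import helion.language as hl` alias are illustrative assumptions, not part of the diff):

import torch
import helion
import helion.language as hl

@helion.kernel
def signal_all(signal_pad: torch.Tensor) -> torch.Tensor:
    (n,) = signal_pad.shape
    for i in hl.grid(n):  # one program per barrier slot
        hl.signal(signal_pad, [i], signal=1)  # atomically set signal_pad[i] to 1
    return signal_pad

signal_pad = torch.zeros(4, device="cuda", dtype=torch.int32)
signal_all(signal_pad)  # every slot now holds 1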

File tree

6 files changed: +275 -1 lines changed

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
 from .scan_ops import cumsum as cumsum
+from .signal_wait import signal as signal
 from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size

helion/language/signal_wait.py

Lines changed: 108 additions & 0 deletions
@@ -6,13 +6,16 @@
 from torch.fx import has_side_effect

 from .. import exc
+from .._compiler.indexing_strategy import SubscriptIndexing
 from . import _decorators

 if TYPE_CHECKING:
     import ast

     from .._compiler.inductor_lowering import CodegenState

+__all__ = ["signal", "wait"]
+

 @has_side_effect
 @_decorators.api(tiles_as_sizes=True)
@@ -146,3 +149,108 @@ def _(state: CodegenState) -> ast.AST:
         signal=signal_expr,
         update=update_expr,
     )
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def signal(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    """Set the signal_pad slice to the signal value.
+    Args:
+        signal_pad: The signal pad to signal
+        index: Indices to index into the signal_pad tensor
+        signal: the value to send
+        op: The atomic op used to send the signal (default: 'atomic_xchg')
+        sem: The memory semantics for sending the signal (default: 'release')
+        scope: The scope of the barrier (default: 'gpu')
+        skip_sync: Skip the syncthreads before sending the signal (default: False)
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> tuple[torch.Tensor, object, int, str, str, str, bool]:
+    from helion.language.tile_proxy import Tile
+
+    valid_ops = {"atomic_add", "atomic_xchg"}
+    valid_sems = {"relaxed", "release", "acq_rel"}
+    valid_scopes = {"sys", "gpu"}
+
+    if op not in valid_ops:
+        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}.")
+
+    if sem not in valid_sems:
+        raise ValueError(
+            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
+        )
+
+    if scope not in valid_scopes:
+        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
+
+    index = Tile._prepare_index(index)
+    index = Tile._tiles_to_sizes(index)
+
+    return (signal_pad, index, signal, op, sem, scope, skip_sync)
+
+
+@_decorators.register_fake(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))
+
+
+@_decorators.codegen(signal)
+def _(state: CodegenState) -> ast.AST:
+    import ast
+
+    from .._compiler.ast_extension import expr_from_string
+    from .._compiler.indexing_strategy import SubscriptIndexing
+
+    signal_pad = state.proxy_arg(0)
+    index = state.proxy_arg(1)
+    signal = state.proxy_arg(2)
+    op = state.proxy_arg(3)
+    sem = state.proxy_arg(4)
+    scope = state.proxy_arg(5)
+    skip_sync = state.proxy_arg(6)
+
+    assert isinstance(signal_pad, torch.Tensor)
+    assert isinstance(index, list)
+
+    indices = SubscriptIndexing.create(state, signal_pad, index)
+    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+
+    signal_expr = ast.Constant(value=signal)
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    hl_ext_call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+
+    return expr_from_string(
+        hl_ext_call_triton_send_signal,
+        offset=indices.index_expr,
+        signal=signal_expr,
+    )
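
The prepare_args hook above restricts op, sem, and scope before codegen runs. A quick sketch of calls that pass validation versus one that is rejected (kernel-body fragments only; the atomic_max call is a deliberately invalid example):

# inside a Helion kernel body, for some index i:
hl.signal(signal_pad, [i])                              # defaults: op='atomic_xchg', sem='release', scope='gpu'
hl.signal(signal_pad, [i], signal=2, op="atomic_add")   # accumulate into the barrier instead of exchanging
hl.signal(signal_pad, [i], sem="relaxed", scope="sys")  # system-wide scope, relaxed ordering
hl.signal(signal_pad, [i], op="atomic_max")             # ValueError: not in {'atomic_add', 'atomic_xchg'}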

helion/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 from .config import Config as Config
 from .kernel import Kernel as Kernel
 from .kernel import kernel as kernel
+from .triton_helpers import triton_send_signal as triton_send_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal


helion/runtime/triton_helpers.py

Lines changed: 61 additions & 1 deletion
@@ -3,7 +3,53 @@
 import triton
 import triton.language as tl

-__all__ = ["triton_wait_signal"]
+__all__ = ["triton_send_signal", "triton_wait_multiple_signal", "triton_wait_signal"]
+
+
+@triton.jit
+def triton_send_signal(
+    addr: tl.tensor,
+    update: tl.constexpr,
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    """
+    Signal global memory barrier(s).
+
+    This function atomically sets global memory barriers to an update value,
+    signaling to other CTAs waiting on the barrier(s).
+
+    Args:
+        addr: Memory address of the barrier(s) to signal
+        update: Value to set the barrier(s) to
+        sem: Memory semantics for the atomic operation. Options: "release", "relaxed".
+        scope: Scope of the atomic operation. Options: "gpu", "sys".
+        op: Atomic operation type: "atomic_xchg", "atomic_add".
+        skip_sync: Skip CTA synchronization before setting the barrier (default: False).
+    """
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
+    tl.static_assert(
+        sem == "release" or sem == "relaxed",
+        "Invalid memory semantic. Options: 'release', 'relaxed'.",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. Options: 'gpu', 'sys'."
+    )
+
+    if op == "atomic_xchg":
+        tl.atomic_xchg(addr, update, sem=sem, scope=scope)
+    elif op == "atomic_add":
+        tl.atomic_add(addr, update, sem=sem, scope=scope)
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for send signal on gmem barrier."
+        )


 @triton.jit
@@ -71,3 +117,17 @@ def triton_wait_signal(
         "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
     )
     # tl.debug_barrier() causes a significant performance loss (perhaps it breaks Triton prefetching?)
+
+
+@triton.jit
+def triton_wait_multiple_signal(
+    addr: tl.tensor,
+    expect: tl.constexpr,  # wait until the lock is set to `expect`
+    update: tl.constexpr,  # update the lock once it is acquired
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    raise NotImplementedError("Waiting on multiple barriers is not implemented yet.")
+    # TODO(joydddd): wait on multiple barriers at the same time, where each thread waits on a different barrier
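
For reference, a sketch of how generated device code is expected to invoke the new helper, mirroring the expected-output journal below (the kernel and argument names are illustrative):

import triton
import triton.language as tl
import helion

@triton.jit
def _signal_kernel(signal_pad, signal_pad_stride_0):
    pid_0 = tl.program_id(0)
    # each program releases its own slot of the barrier pad
    helion.runtime.triton_send_signal(
        addr=signal_pad + pid_0 * signal_pad_stride_0,
        update=1,
        sem="release",
        scope="gpu",
        op="atomic_xchg",
        skip_sync=False,
    )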

test/test_signal_wait.expected

Lines changed: 51 additions & 0 deletions
@@ -1,6 +1,57 @@
 This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

+--- assertExpectedJournal(TestWait.test_signal_basic)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_signal_scalar_bar_kernel_kernel(signal_pad, signal_pad_stride_0):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    helion.runtime.triton_send_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
+
+def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _gmem_signal_scalar_bar_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_signal_scalar_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_signal_scalar_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
+
+--- assertExpectedJournal(TestWait.test_signal_multiple)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_send_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)
+
+def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    _gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestWait.test_wait_2d_tile)
 from __future__ import annotations

test/test_signal_wait.py

Lines changed: 53 additions & 0 deletions
@@ -54,6 +54,59 @@ def wait_for_2d_tile_kernel(
         torch.testing.assert_close(result, x)
         self.assertExpectedJournal(code)

+    def test_signal_basic(self):
+        @helion.kernel
+        def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.signal(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_signal_scalar_bar_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_signal_multiple(self):
+        @helion.kernel
+        def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for tile in hl.tile(n):
+                hl.signal(signal_pad, [tile], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_signal_tensor_bar_kernel,
+            (signal_pad,),
+            block_size=[4],
+        )
+        torch.testing.assert_close(
+            result, torch.ones(16, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_send_receive_cta(self):
+        @helion.kernel
+        def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):  # the first N CTAs send the signal
+                hl.signal(signal_pad, [i], signal=1)
+            for i in hl.grid(n):  # the last N CTAs wait for the signal
+                hl.wait(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_signal_n_wait_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("helion.runtime.triton_send_signal", code)
+        self.assertIn("helion.runtime.triton_wait_signal", code)
+

 if __name__ == "__main__":
     unittest.main()
