
Commit 9f39f63

Add hl.signal
stack-info: PR: #233, branch: joydddd/stack/8
1 parent df1c9b1 commit 9f39f63

File tree

5 files changed: +202, -1 lines changed


helion/_triton_ext/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 from __future__ import annotations

+from .gmem_barrier import _triton_send_signal
 from .gmem_barrier import _triton_wait_multiple_signal
 from .gmem_barrier import _triton_wait_signal

-__all__ = ["_triton_wait_multiple_signal", "_triton_wait_signal"]
+__all__ = ["_triton_send_signal", "_triton_wait_multiple_signal", "_triton_wait_signal"]

helion/_triton_ext/gmem_barrier.py

Lines changed: 45 additions & 0 deletions

@@ -4,6 +4,51 @@
 import triton
 import triton.language as tl

+__all__ = ["_triton_send_signal", "_triton_wait_multiple_signal", "_triton_wait_signal"]
+
+
+@triton.jit
+def _triton_send_signal(
+    addr,  # can be a scalar or a vector of pointers
+    update: tl.constexpr,
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    """
+    Send a signal to a global memory barrier.
+
+    Atomically updates the barrier location(s) at `addr` so that CTAs
+    spin-waiting on the barrier observe the new value, optionally issuing a
+    block-level sync first so all threads in the block have finished their
+    work before the signal becomes visible.
+
+    Args:
+        addr: Memory address(es) of the barrier(s) to signal
+        update: Value to write into the barrier
+        sem: Memory semantic, 'release' or 'relaxed'
+        scope: Synchronization scope, 'gpu' or 'sys'
+        op: Atomic op used for the update, 'atomic_xchg' or 'atomic_add'
+        skip_sync: If True, skip the block-level bar.sync before signaling
+    """
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
+    tl.static_assert(
+        sem == "release" or sem == "relaxed",
+        "Invalid memory semantic. options: 'release', 'relaxed'.",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'."
+    )
+
+    if op == "atomic_xchg":
+        tl.atomic_xchg(addr, update, sem=sem, scope=scope)
+    elif op == "atomic_add":
+        tl.atomic_add(addr, update, sem=sem, scope=scope)
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for send signal on gmem barrier."
+        )
+

 @triton.jit
 def _triton_wait_signal(
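
For illustration, a minimal sketch (not part of the commit) of how a hand-written Triton kernel could call the new helper; the kernel name `_producer_kernel` and the launch shape are hypothetical:

import torch
import triton
import triton.language as tl

from helion._triton_ext import _triton_send_signal


@triton.jit
def _producer_kernel(out_ptr, signal_pad_ptr):
    pid = tl.program_id(0)
    tl.store(out_ptr + pid, pid)  # the "work" this CTA produces
    # Release-signal this CTA's barrier slot so any waiter spinning on it
    # can proceed once the store above is visible.
    _triton_send_signal(
        signal_pad_ptr + pid,
        update=1,
        sem="release",
        scope="gpu",
        op="atomic_xchg",
        skip_sync=False,
    )


out = torch.empty(4, device="cuda", dtype=torch.int32)
signal_pad = torch.zeros(4, device="cuda", dtype=torch.int32)
_producer_kernel[(4,)](out, signal_pad)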

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .signal_wait import signal as signal
 from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
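
With this re-export the op is callable as `hl.signal` from device loops. Mirroring the tests added below, a minimal sketch (the kernel name `send_signals` is illustrative, and a CUDA device is assumed):

import torch
import helion
import helion.language as hl


@helion.kernel
def send_signals(signal_pad: torch.Tensor) -> torch.Tensor:
    (n,) = signal_pad.shape
    for i in hl.grid(n):
        # Defaults: op="atomic_xchg", sem="release", scope="gpu".
        hl.signal(signal_pad, [i], signal=1)
    return signal_pad


pad = torch.zeros(4, device="cuda", dtype=torch.int32)
out = send_signals(pad)  # every slot of the pad becomes 1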

helion/language/signal_wait.py

Lines changed: 101 additions & 0 deletions

@@ -6,12 +6,16 @@
 from torch.fx import has_side_effect

 from .. import exc
+from .._compiler.indexing_strategy import SubscriptIndexing
 from . import _decorators

 if TYPE_CHECKING:
     import ast

     from .._compiler.inductor_lowering import CodegenState
+    from helion._compiler.type_propagation import SymIntType
+
+__all__ = ["signal", "wait"]

@@ -195,3 +199,100 @@ def _(state: CodegenState) -> ast.AST:
         signal=signal_expr,
         update=update_expr,
     )
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def signal(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor | SymIntType:
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> tuple[torch.Tensor, object, int, str, str, str, bool]:
+    from helion.language.tile_proxy import Tile
+
+    valid_ops = {"atomic_add", "atomic_xchg"}
+    valid_sems = {"relaxed", "release", "acq_rel"}
+    valid_scopes = {"sys", "gpu"}
+
+    if op not in valid_ops:
+        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}.")
+
+    if sem not in valid_sems:
+        raise ValueError(
+            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
+        )
+
+    if scope not in valid_scopes:
+        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
+
+    index = Tile._prepare_index(index)
+    index = Tile._tiles_to_sizes(index)
+
+    return (signal_pad, index, signal, op, sem, scope, skip_sync)
+
+
+@_decorators.register_fake(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))
+
+
+@_decorators.codegen(signal)
+def _(state: CodegenState) -> ast.AST:
+    import ast
+
+    from .._compiler.ast_extension import expr_from_string
+    from .._compiler.indexing_strategy import SubscriptIndexing
+
+    signal_pad = state.proxy_arg(0)
+    index = state.proxy_arg(1)
+    signal = state.proxy_arg(2)
+    op = state.proxy_arg(3)
+    sem = state.proxy_arg(4)
+    scope = state.proxy_arg(5)
+    skip_sync = state.proxy_arg(6)
+
+    assert isinstance(signal_pad, torch.Tensor)
+    assert isinstance(index, list)
+
+    indices = SubscriptIndexing.create(state, signal_pad, index)
+    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+
+    signal_expr = ast.Constant(value=signal)
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    hl_ext_call_triton_send_signal = (
+        f"hl_ext._triton_send_signal(addr={signal_pad_name} + offset, "
+        f"update=signal, sem='{sem}', scope='{scope}', op='{op}', "
+        f"skip_sync={skip_sync})"
+    )
+
+    return expr_from_string(
+        hl_ext_call_triton_send_signal,
+        offset=indices.index_expr,
+        signal=signal_expr,
+    )
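
For a call that uses the defaults, the codegen above emits a Triton-source fragment of the following shape (`signal_pad` and `offset` stand in for generated argument and index names):

hl_ext._triton_send_signal(addr=signal_pad + offset, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)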

test/test_signal_wait.py

Lines changed: 53 additions & 0 deletions

@@ -101,6 +101,59 @@ def _wait_for_2d_tile_kernel_make_precompiler(signal_pad: torch.Tensor, x: torch
             code,
         )

+    def test_signal_basic(self):
+        @helion.kernel
+        def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.signal(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_signal_scalar_bar_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("hl_ext._triton_send_signal", code)
+
+    def test_signal_multiple(self):
+        @helion.kernel
+        def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for tile in hl.tile(n):
+                hl.signal(signal_pad, [tile], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_signal_tensor_bar_kernel,
+            (signal_pad,),
+            block_size=[4],
+        )
+        torch.testing.assert_close(
+            result, torch.ones(16, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("hl_ext._triton_send_signal", code)
+
+    def test_send_receive_cta(self):
+        @helion.kernel
+        def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):  # the first n CTAs send the signal
+                hl.signal(signal_pad, [i], signal=1)
+            for i in hl.grid(n):  # the last n CTAs wait for the signal
+                hl.wait(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_signal_n_wait_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("hl_ext._triton_send_signal", code)
+        self.assertIn("hl_ext._triton_wait_signal", code)
+

 if __name__ == "__main__":
     unittest.main()
