
Commit 915fd9f

Add hl.signal
stack-info: PR: #233, branch: joydddd/stack/8
1 parent 67a72f1 commit 915fd9f

3 files changed, +101 -0 lines changed

helion/_triton_ext/gmem_barrier.py

Lines changed: 43 additions & 0 deletions
@@ -5,6 +5,49 @@
 import triton.language as tl


+@triton.jit
+def _triton_send_signal(
+    addr,  # can be a scalar or a vector of pointers.
+    update: tl.constexpr,
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    """
+    Send a signal to a global memory barrier.
+
+    Atomically updates the barrier at `addr` (with the given memory semantic,
+    scope, and atomic op) so that threads spinning in `_triton_wait_signal`
+    observe the new value.
+
+    Args:
+        addr: Barrier address(es) to signal (scalar or vector of pointers)
+        update: Value to write to the barrier
+    """
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
+    tl.static_assert(
+        sem == "release" or sem == "relaxed",
+        "Invalid memory semantic. options: 'release', 'relaxed'. ",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu','sys'. "
+    )
+
+    if op == "atomic_xchg":
+        tl.atomic_xchg(addr, update, sem=sem, scope=scope)
+    elif op == "atomic_add":
+        tl.atomic_add(addr, update, sem=sem, scope=scope)
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for send signal on gmem barrier. "
+        )
+
+
 @triton.jit
 def _triton_wait_signal(
     addr,
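
For orientation, here is a minimal sketch (not part of the diff) of how the new sender pairs with the existing `_triton_wait_signal` spin loop. The producer/consumer kernel names, the barrier tensor, and the host-side launch are hypothetical; the helper signatures and the `hl_ext` alias follow this file and the generated code asserted in test_signal_wait.py.

import torch
import triton

from helion import _triton_ext as hl_ext  # gmem barrier helpers


@triton.jit
def _producer_kernel(barrier_ptr):
    # Illustrative: publish "ready" (1) with release semantics so that earlier
    # global-memory stores become visible to whoever observes the flag.
    hl_ext._triton_send_signal(
        barrier_ptr, update=1, sem="release", scope="gpu",
        op="atomic_xchg", skip_sync=False,
    )


@triton.jit
def _consumer_kernel(barrier_ptr):
    # Illustrative: spin until the producer has written 1 (argument names match
    # the existing _triton_wait_signal helper as used in the tests).
    hl_ext._triton_wait_signal(
        addr=barrier_ptr, expect=1, update=0, sem="acquire", scope="gpu",
        op="ld", skip_sync=False,
    )


barrier = torch.zeros(1, device="cuda", dtype=torch.int32)
_producer_kernel[(1,)](barrier)
_consumer_kernel[(1,)](barrier)

Launched back-to-back on one stream this ordering is trivial; the interesting case is two concurrent kernels coordinating through the same barrier slot.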

helion/language/signal_wait.py

Lines changed: 30 additions & 0 deletions
@@ -12,6 +12,9 @@
 import ast

 from .._compiler.inductor_lowering import CodegenState
+from helion._compiler.type_propagation import SymIntType
+
+__all__ = ["signal", "wait"]


 @has_side_effect
@@ -195,3 +198,30 @@ def _(state: CodegenState) -> ast.AST:
         signal=signal_expr,
         update=update_expr,
     )
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def signal(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor | SymIntType:
+    raise exc.NotInsideKernel
+
+
+@_decorators.register_fake(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))
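
As a usage sketch (not part of the diff): inside a Helion device loop, a call like the one below presumably lowers to a single `hl_ext._triton_send_signal` call, with `signal` supplying the `update` value and `op`/`sem`/`scope`/`skip_sync` forwarded as constexprs. The kernel name, tensor shape, and the `op="atomic_add"` choice are illustrative only; the `hl` alias follows the existing tests.

import torch

import helion
import helion.language as hl


@helion.kernel
def producer_done_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
    (n,) = signal_pad.shape
    for i in hl.grid(n):
        # ... produce block i's result here ...
        # Bump the per-block counter instead of overwriting it, by choosing
        # atomic_add rather than the default atomic_xchg.
        hl.signal(signal_pad, [i], signal=1, op="atomic_add")
    return signal_pad


signal_pad = torch.zeros(8, device="cuda", dtype=torch.int32)
producer_done_kernel(signal_pad)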

test/test_signal_wait.py

Lines changed: 28 additions & 0 deletions
@@ -101,6 +101,34 @@ def _wait_for_2d_tile_kernel_make_precompiler(signal_pad: torch.Tensor, x: torch
             code,
         )

+    def test_basic_signal(self):
+        @helion.kernel
+        def gmem_signal_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.signal(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.ones(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_signal_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.maxDiff = None
+        self.assertIn(
+            "from helion import _triton_ext as hl_ext", code
+        )  # Import hl_ext.
+        self.assertIn(
+            """\
+@triton.jit
+def _gmem_signal_kernel_kernel(signal_pad, signal_pad_stride_0):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    hl_ext._triton_send_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)""",
+            code,
+        )
+

 if __name__ == "__main__":
     unittest.main()
