Skip to content

Commit c45e1dc

Browse files
committed
Add hl.wait for simultenous waiting for multiple gmem barriers
stack-info: PR: #243, branch: joydddd/stack/11
1 parent 340876e commit c45e1dc

File tree

5 files changed

+219
-10
lines changed

5 files changed

+219
-10
lines changed

helion/language/signal_wait.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,6 @@ def _(
8888
if scope not in valid_scopes:
8989
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
9090

91-
# TODO(joydddd): add support for non scalar index into signal_pad
92-
for i in index:
93-
assert isinstance(i, int | torch.SymInt)
94-
9591
index = Tile._prepare_index(index)
9692
index = Tile._tiles_to_sizes(index)
9793

@@ -141,7 +137,17 @@ def _(state: CodegenState) -> ast.AST:
141137
assert type(sem) is str
142138
assert type(scope) is str
143139

144-
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
140+
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
141+
is_scalar = len(bar_tensor_shape) == 0
142+
143+
if is_scalar:
144+
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
145+
else:
146+
if signal_pad.dtype not in (torch.int32, torch.uint32):
147+
raise NotImplementedError(
148+
f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
149+
)
150+
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
145151

146152
return expr_from_string(
147153
call_triton_wait_signal,

helion/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .kernel import Kernel as Kernel
1010
from .kernel import kernel as kernel
1111
from .triton_helpers import triton_send_signal as triton_send_signal
12+
from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
1213
from .triton_helpers import triton_wait_signal as triton_wait_signal
1314

1415

helion/runtime/triton_helpers.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,106 @@ def triton_wait_signal(
122122
@triton.jit
123123
def triton_wait_multiple_signal(
124124
addr: tl.tensor,
125-
expect: tl.constexpr, # wait until lock is set to expect
126-
update: tl.constexpr, # update the lock once it is aquired.
125+
expect: tl.constexpr,
126+
update: tl.constexpr,
127127
sem: tl.constexpr,
128128
scope: tl.constexpr,
129129
op: tl.constexpr,
130130
skip_sync: tl.constexpr,
131131
) -> None:
132-
raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
133-
# TODO(joydddd): waiting on multiple barriers at the same time whereeach thread waits on a different barrier
132+
"""
133+
Simultenuoslly wait for multiple global memory barrier to reach the expected value.
134+
135+
This function implements each thread in a CTA spin-waits and continuously checks a memory location until it reaches the expected value, providing synchronization across CTAs.
136+
137+
Args:
138+
addr: Memory addresses of the barriers to wait on (Maximum 32 barriers)
139+
expect: Expected value to wait for
140+
update: Update the barrier with once acquired
141+
sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
142+
scope: Scope of the atomic operation. Options: "gpu", "sys"
143+
op: Atomic operation type: "ld", "atomic_cas"
144+
skip_sync: Skip CTA synchronization after acquiring the barrier. (default: False)
145+
"""
146+
tl.static_assert(
147+
sem == "acquire" or sem == "relaxed",
148+
"Invalid memory semantic. options: 'acquire', 'relaxed'. ",
149+
)
150+
tl.static_assert(
151+
scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
152+
)
153+
tl.static_assert(
154+
op == "ld" or op == "atomic_cas",
155+
"Invalid op. options: 'ld', 'atomic_cas'. ",
156+
)
157+
158+
tl.static_assert(
159+
addr.dtype == tl.pointer_type(tl.int32),
160+
"Invalid barrier value type. Only supports int32 for multi barrier signal. ",
161+
)
162+
163+
addr = tl.ravel(addr)
164+
165+
tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
166+
tl.static_assert(addr.shape[0] <= 32, "Wait on at most 32 barriers at a time. ")
167+
168+
# Assume Triton always sets tid.y == tid.z == 0.
169+
if op == "ld":
170+
tl.inline_asm_elementwise(
171+
f"""
172+
{{
173+
.reg .u32 %tmp32_<3>;
174+
.reg .pred %p<2>;
175+
176+
mov.u32 %tmp32_0, %tid.x;
177+
setp.le.s32 %p1, %tmp32_0, $2;
178+
179+
mov.u32 $0, 0;
180+
// initialize tmp_0 to 0
181+
wait_block:
182+
@%p1 ld.global.{sem}.{scope}.u32 $0, [$1];
183+
setp.ne.u32 %p0, $0, $3;
184+
and.pred %p0, %p0, %p1;
185+
@%p0 bra wait_block;
186+
}}
187+
""",
188+
"=r, l, r, r",
189+
[addr, addr.shape[0], expect],
190+
dtype=addr.dtype.element_ty,
191+
is_pure=False,
192+
pack=1,
193+
)
194+
elif op == "atomic_cas":
195+
tl.inline_asm_elementwise(
196+
f"""
197+
{{
198+
.reg .u32 %tmp32_<3>;
199+
.reg .pred %p<2>;
200+
201+
mov.u32 %tmp32_0, %tid.x;
202+
setp.le.s32 %p1, %tmp32_0, $2;
203+
204+
mov.u32 $0, 0;
205+
// initialize tmp_0 to 0
206+
wait_block:
207+
@%p1 atom.global.{sem}.{scope}.cas.b32 $0, [$1], $3, $4;
208+
setp.ne.u32 %p0, $0, $3;
209+
and.pred %p0, %p0, %p1;
210+
@%p0 bra wait_block;
211+
}}
212+
""",
213+
"=r, l, r, r, r",
214+
[addr, addr.shape[0], expect, update],
215+
dtype=addr.dtype.element_ty,
216+
is_pure=False,
217+
pack=1,
218+
)
219+
else:
220+
raise NotImplementedError(
221+
f"Unsupported op '{op}' for wait signal on gmem barrier. "
222+
)
223+
224+
if not skip_sync:
225+
tl.inline_asm_elementwise(
226+
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
227+
)

test/test_signal_wait.expected

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,37 @@
11
This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
22
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
33

4+
--- assertExpectedJournal(TestWait.test_global_sync)
5+
from __future__ import annotations
6+
7+
import torch
8+
import helion
9+
import triton
10+
import triton.language as tl
11+
12+
@triton.jit
13+
def _gmem_multi_bar_sync_kernel_kernel(signal_pad, signal_pad_stride_0, signal_pad_stride_1, N, _BLOCK_SIZE_1: tl.constexpr):
14+
pid_0 = tl.program_id(0)
15+
offset_0 = pid_0
16+
for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
17+
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
18+
helion.runtime.triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=True)
19+
helion.runtime.triton_wait_multiple_signal(addr=signal_pad + (offset_0 * signal_pad_stride_0 + indices_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
20+
21+
def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor):
22+
M, N = signal_pad.shape
23+
assert M == N
24+
_BLOCK_SIZE_1 = N
25+
_gmem_multi_bar_sync_kernel_kernel[N,](signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
26+
return signal_pad
27+
28+
def _gmem_multi_bar_sync_kernel_make_precompiler(signal_pad: torch.Tensor):
29+
M, N = signal_pad.shape
30+
assert M == N
31+
_BLOCK_SIZE_1 = N
32+
from helion.runtime.precompile_shim import make_precompiler
33+
return make_precompiler(_gmem_multi_bar_sync_kernel_kernel)(signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
34+
435
--- assertExpectedJournal(TestWait.test_signal_basic)
536
from __future__ import annotations
637

@@ -120,3 +151,38 @@ def _gmem_wait_kernel_make_precompiler(signal_pad: torch.Tensor):
120151
from helion.runtime.precompile_shim import make_precompiler
121152
return make_precompiler(_gmem_wait_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), num_warps=4, num_stages=3)
122153

154+
--- assertExpectedJournal(TestWait.test_wait_multi_bar)
155+
from __future__ import annotations
156+
157+
import torch
158+
import helion
159+
import triton
160+
import triton.language as tl
161+
162+
import __main__ as _source_module
163+
164+
@triton.jit
165+
def _gmem_wait_multi_bar_kernel_kernel(signal_pad, out, out_stride_0, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
166+
pid_0 = tl.program_id(0)
167+
offset_0 = pid_0 * _BLOCK_SIZE_0
168+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
169+
helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
170+
tile_id = offset_0 // _BLOCK_SIZE_0
171+
tl.store(out + tile_id * out_stride_0, tile_id, None)
172+
173+
def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor):
174+
N, = signal_pad.shape
175+
n = 4
176+
out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
177+
_BLOCK_SIZE_0 = 4
178+
_gmem_wait_multi_bar_kernel_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
179+
return out
180+
181+
def _gmem_wait_multi_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
182+
N, = signal_pad.shape
183+
n = 4
184+
out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
185+
_BLOCK_SIZE_0 = 4
186+
from helion.runtime.precompile_shim import make_precompiler
187+
return make_precompiler(_gmem_wait_multi_bar_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
188+

test/test_signal_wait.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,29 @@ def wait_for_2d_tile_kernel(
5454
torch.testing.assert_close(result, x)
5555
self.assertExpectedJournal(code)
5656

57+
def test_wait_multi_bar(self):
58+
@helion.kernel
59+
def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
60+
(N,) = signal_pad.shape
61+
n = hl.register_block_size(N)
62+
out = torch.empty(n, dtype=torch.int32, device=DEVICE)
63+
64+
for tile in hl.tile(N, block_size=n):
65+
hl.wait(signal_pad, [tile], signal=1)
66+
out[tile.id] = tile.id
67+
68+
return out
69+
70+
signal_pad = torch.ones(16, device=DEVICE, dtype=torch.int32)
71+
code, result = code_and_output(
72+
gmem_wait_multi_bar_kernel, (signal_pad,), block_size=[4]
73+
)
74+
torch.testing.assert_close(
75+
result, torch.arange(4, device=DEVICE, dtype=torch.int32)
76+
)
77+
self.maxDiff = None
78+
self.assertExpectedJournal(code)
79+
5780
def test_signal_basic(self):
5881
@helion.kernel
5982
def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -88,7 +111,7 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
88111
)
89112
self.assertExpectedJournal(code)
90113

91-
def test_sent_recieve_cta(self):
114+
def test_send_recieve_cta(self):
92115
@helion.kernel
93116
def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
94117
(n,) = signal_pad.shape
@@ -107,6 +130,25 @@ def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
107130
self.assertIn("helion.runtime.triton_send_signal", code)
108131
self.assertIn("helion.runtime.triton_wait_signal", code)
109132

133+
def test_global_sync(self):
134+
@helion.kernel
135+
def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
136+
M, N = signal_pad.shape
137+
assert M == N
138+
for i in hl.grid(N):
139+
for tile in hl.tile(N, block_size=N):
140+
hl.signal(signal_pad, [tile, i], signal=1, skip_sync=True)
141+
hl.wait(signal_pad, [i, tile], signal=1)
142+
return signal_pad
143+
144+
signal_pad = torch.zeros(4, 4, device=DEVICE, dtype=torch.int32)
145+
146+
code, result = code_and_output(gmem_multi_bar_sync_kernel, (signal_pad,))
147+
torch.testing.assert_close(
148+
result, torch.ones(4, 4, device=DEVICE, dtype=torch.int32)
149+
)
150+
self.assertExpectedJournal(code)
151+
110152

111153
if __name__ == "__main__":
112154
unittest.main()

0 commit comments

Comments
 (0)