Commit dc22b98

Add hl.wait for simultaneous waiting on multiple gmem barriers
stack-info: PR: #243, branch: joydddd/stack/11
1 parent 52f0183 commit dc22b98

File tree

5 files changed: +342 additions, -12 deletions

helion/language/signal_wait.py

Lines changed: 13 additions & 7 deletions
@@ -76,7 +76,7 @@ def _(
             f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
         )
 
-    if op == "atomic_cas" and not update:
+    if op == "atomic_cas" and update is None:
         raise ValueError(
             f"{op} without an update value. Do you want to use 'ld' instead? "
         )
@@ -88,10 +88,6 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
 
-    # TODO(joydddd): add support for non scalar index into signal_pad
-    for i in index:
-        assert isinstance(i, int | torch.SymInt)
-
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)
 
@@ -141,7 +137,17 @@ def _(state: CodegenState) -> ast.AST:
     assert type(sem) is str
     assert type(scope) is str
 
-    call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+    bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+    is_scalar = len(bar_tensor_shape) == 0
+
+    if is_scalar:
+        call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+    else:
+        if signal_pad.dtype not in (torch.int32, torch.uint32):
+            raise NotImplementedError(
+                f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
+            )
+        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
 
     return expr_from_string(
         call_triton_wait_signal,
@@ -272,7 +278,7 @@ def _(state: CodegenState) -> ast.AST:
     if is_scalar:
         call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
     else:
-        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync), sync_after=True)"
+        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
 
     return expr_from_string(
         call_triton_wait_signal,
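
For orientation, here is a minimal sketch of the kind of Helion kernel the new non-scalar branch serves. The kernel name and the exact hl.wait signature are assumptions for illustration, not taken from this diff; the point is that indexing signal_pad with a tile yields a non-scalar barrier tensor, so codegen emits triton_wait_multiple_signal instead of triton_wait_signal.

import torch
import helion
import helion.language as hl


@helion.kernel
def wait_on_tile_of_barriers(signal_pad: torch.Tensor) -> torch.Tensor:
    # Assumed setup: signal_pad is a 1-D int32 tensor of gmem barriers.
    (n,) = signal_pad.shape
    out = torch.zeros_like(signal_pad)
    for tile in hl.tile(n):
        # Non-scalar index: every element of the tile is its own barrier, so this
        # wait lowers to helion.runtime.triton_wait_multiple_signal(...).
        hl.wait(signal_pad, [tile], signal=1)
        out[tile] = 1
    return out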

helion/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from .kernel import Kernel as Kernel
 from .kernel import kernel as kernel
 from .triton_helpers import triton_send_signal as triton_send_signal
+from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal

helion/runtime/triton_helpers.py

Lines changed: 98 additions & 4 deletions
@@ -132,13 +132,107 @@ def triton_wait_signal(
 @triton.jit
 def triton_wait_multiple_signal(
     addr: tl.tensor,
-    expect: tl.constexpr,  # wait until lock is set to expect
-    update: tl.constexpr,  # update the lock once it is aquired.
+    expect: tl.constexpr,
+    update: tl.constexpr,
     sem: tl.constexpr,
     scope: tl.constexpr,
     op: tl.constexpr,
     skip_sync: tl.constexpr,
     sync_before: tl.constexpr = False,  # pyre-ignore[9]
 ) -> None:
-    raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
-    # TODO(joydddd): waiting on multiple barriers at the same time where each thread waits on a different barrier
+    """
+    Simultaneously wait for multiple global memory barriers to reach the expected value.
+
+    Each thread in a CTA spin-waits on its own barrier, continuously checking the memory
+    location until it reaches the expected value, providing synchronization across CTAs.
+
+    Args:
+        addr: Memory addresses of the barriers to wait on (maximum 32 barriers)
+        expect: Expected value to wait for
+        update: Value to update the barrier with once it is acquired
+        sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
+        scope: Scope of the atomic operation. Options: "gpu", "sys"
+        op: Atomic operation type: "ld", "atomic_cas"
+        skip_sync: Skip CTA synchronization after acquiring the barrier. (default: False)
+    """
+    tl.static_assert(
+        (sem == "acquire" or sem == "relaxed") or sem == "release",
+        "Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
+    )
+    tl.static_assert(
+        op == "ld" or op == "atomic_cas",
+        "Invalid op. options: 'ld', 'atomic_cas'. ",
+    )
+
+    tl.static_assert(
+        addr.dtype == tl.pointer_type(tl.int32),
+        "Invalid barrier value type. Only supports int32 for multi barrier signal. ",
+    )
+
+    addr = tl.ravel(addr)
+
+    tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
+    tl.static_assert(addr.shape[0] <= 32, "Wait on at most 32 barriers at a time. ")
+
+    # Assume Triton always sets tid.y == tid.z == 0.
+    if op == "ld":
+        tl.inline_asm_elementwise(
+            f"""
+            {{
+                .reg .u32   %tmp32_<3>;
+                .reg .pred  %p<2>;
+
+                mov.u32 %tmp32_0, %tid.x;
+                setp.lt.s32 %p1, %tmp32_0, $2;
+
+                mov.u32 $0, 0;
+                // initialize tmp_0 to 0
+                wait_block:
+                    @%p1 ld.global.{sem}.{scope}.u32 $0, [$1];
+                    setp.ne.u32 %p0, $0, $3;
+                    and.pred %p0, %p0, %p1;
+                    @%p0 bra wait_block;
+            }}
+            """,
+            "=r, l, r, r",
+            [addr, addr.shape[0], expect],
+            dtype=addr.dtype.element_ty,
+            is_pure=False,
+            pack=1,
+        )
+    elif op == "atomic_cas":
+        tl.inline_asm_elementwise(
+            f"""
+            {{
+                .reg .u32   %tmp32_<3>;
+                .reg .pred  %p<2>;
+
+                mov.u32 %tmp32_0, %tid.x;
+                setp.lt.s32 %p1, %tmp32_0, $2;
+
+                mov.u32 $0, 0;
+                // initialize tmp_0 to 0
+                wait_block:
+                    @%p1 atom.global.{sem}.{scope}.cas.b32 $0, [$1], $3, $4;
+                    setp.ne.u32 %p0, $0, $3;
+                    and.pred %p0, %p0, %p1;
+                    @%p0 bra wait_block;
+            }}
+            """,
+            "=r, l, r, r, r",
+            [addr, addr.shape[0], expect, update],
+            dtype=addr.dtype.element_ty,
+            is_pure=False,
+            pack=1,
+        )
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for wait signal on gmem barrier. "
+        )
+
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
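
As a quick illustration of the new helper outside of Helion-generated code, here is a sketch of calling it directly from a hand-written Triton kernel, using the re-export added to helion/runtime/__init__.py above. The kernel, its launch, and the choice of a contiguous 1-D int32 signal_pad are assumptions for the example; the one hard requirement carried over from the helper is that each program waits on at most 32 barriers.

import triton
import triton.language as tl

from helion.runtime import triton_wait_multiple_signal


@triton.jit
def _wait_block_of_barriers(signal_pad, BLOCK: tl.constexpr):
    # Each program owns BLOCK consecutive int32 barriers (BLOCK <= 32).
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Every thread spins on its own slot until it reads 1, then CAS-resets it to 0.
    triton_wait_multiple_signal(
        addr=signal_pad + offs,
        expect=1,
        update=0,
        sem="acquire",
        scope="gpu",
        op="atomic_cas",
        skip_sync=False,
    )


# Hypothetical launch (barriers pre-set to 1 so the wait returns immediately):
# signal_pad = torch.ones(64, dtype=torch.int32, device="cuda")
# _wait_block_of_barriers[(signal_pad.numel() // 32,)](signal_pad, BLOCK=32)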

test/test_signal_wait.expected

Lines changed: 121 additions & 0 deletions
@@ -1,6 +1,37 @@
 This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
 Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
 
+--- assertExpectedJournal(TestWait.test_global_sync)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_multi_bar_sync_kernel_kernel(signal_pad, signal_pad_stride_0, signal_pad_stride_1, N, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    for offset_1 in tl.range(0, N.to(tl.int32), step=_BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        helion.runtime.triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=True)
+        helion.runtime.triton_wait_multiple_signal(addr=signal_pad + (offset_0 * signal_pad_stride_0 + indices_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+
+def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor):
+    M, N = signal_pad.shape
+    assert M == N
+    _BLOCK_SIZE_1 = N
+    _gmem_multi_bar_sync_kernel_kernel[N,](signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_multi_bar_sync_kernel_make_precompiler(signal_pad: torch.Tensor):
+    M, N = signal_pad.shape
+    assert M == N
+    _BLOCK_SIZE_1 = N
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_multi_bar_sync_kernel_kernel)(signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestWait.test_signal_basic)
 from __future__ import annotations
 
@@ -76,6 +107,33 @@ def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
 
+--- assertExpectedJournal(TestWait.test_signal_multiple_cas)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=0, update=1, sem='release', scope='gpu', op='atomic_cas', skip_sync=True, sync_before=not False)
+
+def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    _gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestWait.test_wait_2d_tile)
 from __future__ import annotations
 
@@ -144,3 +202,66 @@ def _gmem_wait_kernel_make_precompiler(signal_pad: torch.Tensor):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_gmem_wait_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), num_warps=4, num_stages=3)
 
+--- assertExpectedJournal(TestWait.test_wait_multi_bar)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+import test.test_signal_wait as _source_module
+
+@triton.jit
+def _gmem_wait_multi_bar_kernel_kernel(signal_pad, out, out_stride_0, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(out + tile_id * out_stride_0, tile_id, None)
+
+def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
+    _BLOCK_SIZE_0 = 4
+    _gmem_wait_multi_bar_kernel_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _gmem_wait_multi_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_wait_multi_bar_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
+--- assertExpectedJournal(TestWait.test_wait_multi_bar_cas)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_wait_multi_bar_kernel_cas_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=2, sem='acquire', scope='gpu', op='atomic_cas', skip_sync=False)
+
+def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    _BLOCK_SIZE_0 = 4
+    _gmem_wait_multi_bar_kernel_cas_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_wait_multi_bar_kernel_cas_make_precompiler(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_wait_multi_bar_kernel_cas_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)