Commit 0eb9276

Add hl.wait for simultaneous waiting for multiple gmem barriers
stack-info: PR: #243, branch: joydddd/stack/11
1 parent 60c19b9 commit 0eb9276
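
This commit lets hl.wait take a tile (tensor) index into the signal pad, so a single call waits on one 32-bit gmem barrier per element of the tile rather than on a single scalar barrier. A minimal usage sketch, adapted from the test_wait_multi_bar test added below (DEVICE and the kernel name are illustrative; any CUDA device with an int32 signal pad works):

    import torch
    import helion
    import helion.language as hl

    DEVICE = torch.device("cuda")

    @helion.kernel
    def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
        (N,) = signal_pad.shape
        n = hl.register_block_size(N)
        out = torch.empty(n, dtype=torch.int32, device=DEVICE)
        for tile in hl.tile(N, block_size=n):
            # Each element of `tile` indexes its own int32 barrier in signal_pad;
            # the block proceeds once every selected barrier reads 1.
            hl.wait(signal_pad, [tile], signal=1)
            out[tile.id] = tile.id
        return out

    # e.g. out = gmem_wait_multi_bar_kernel(torch.ones(16, device=DEVICE, dtype=torch.int32))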

File tree

3 files changed: +155 -9 lines changed


helion/_triton_ext/gmem_barrier.py

Lines changed: 83 additions & 3 deletions
@@ -118,13 +118,93 @@ def _triton_wait_signal(
 
 
 @triton.jit
 def _triton_wait_multiple_signal(
-    addr,
+    addr: tl.tensor,
     expect: tl.constexpr,  # wait until lock is set to expect
     update: tl.constexpr,  # update the lock once it is aquired.
     sem: tl.constexpr,
     scope: tl.constexpr,
     op: tl.constexpr,
     skip_sync: tl.constexpr,
 ) -> None:
-    raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
-    # TODO(joydddd): waiting on multiple barriers at the same time whereeach thread waits on a different barrier
+    tl.static_assert(
+        sem == "acquire" or sem == "relaxed",
+        "Invalid memory semantic. options: 'acquire', 'relaxed'. ",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
+    )
+    tl.static_assert(
+        op == "ld" or op == "atomic_cas",
+        "Invalid op. options: 'ld', 'atomic_cas'. ",
+    )
+
+    tl.static_assert(
+        addr.dtype == tl.pointer_type(tl.int32),
+        "Invalid barrier value type. Only supports int32 for multi barrier signal. ",
+    )
+
+    addr = tl.ravel(addr)
+
+    tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
+    tl.static_assert(addr.shape[0] <= 32, "Wait on at most 32 barriers at a time. ")
+
+    # Assume Triton always sets tid.y == tid.z == 0.
+    if op == "ld":
+        tl.inline_asm_elementwise(
+            f"""
+            {{
+            .reg .u32 %tmp32_<3>;
+            .reg .pred %p<2>;
+
+            mov.u32 %tmp32_0, %tid.x;
+            setp.le.s32 %p1, %tmp32_0, $2;
+
+            mov.u32 $0, 0;
+            // initialize tmp_0 to 0
+            wait_block:
+                @%p1 ld.global.{sem}.{scope}.u32 $0, [$1];
+                setp.ne.u32 %p0, $0, $3;
+                and.pred %p0, %p0, %p1;
+                @%p0 bra wait_block;
+            }}
+            """,
+            "=r, l, r, r",
+            [addr, addr.shape[0], expect],
+            dtype=addr.dtype.element_ty,
+            is_pure=False,
+            pack=1,
+        )
+    elif op == "atomic_cas":
+        tl.inline_asm_elementwise(
+            f"""
+            {{
+            .reg .u32 %tmp32_<3>;
+            .reg .pred %p<2>;
+
+            mov.u32 %tmp32_0, %tid.x;
+            setp.le.s32 %p1, %tmp32_0, $2;
+
+            mov.u32 $0, 0;
+            // initialize tmp_0 to 0
+            wait_block:
+                @%p1 atom.global.{sem}.{scope}.cas.b32 $0, [$1], $3, $4;
+                setp.ne.u32 %p0, $0, $3;
+                and.pred %p0, %p0, %p1;
+                @%p0 bra wait_block;
+            }}
+            """,
+            "=r, l, r, r, r",
+            [addr, addr.shape[0], expect, update],
+            dtype=addr.dtype.element_ty,
+            is_pure=False,
+            pack=1,
+        )
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for wait signal on gmem barrier. "
+        )
+
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
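
In both branches the inline PTX is emitted elementwise over addr: each participating lane gets one barrier pointer, zero-initializes its result register, and spins in wait_block, re-loading the barrier (op='ld') or retrying a compare-and-swap that stores `update` once the barrier equals `expect` (op='atomic_cas'); lanes beyond the barrier count are predicated off, and a final bar.sync 0 re-converges the block unless skip_sync is set. A hedged sketch of invoking the helper from a hand-written Triton kernel (the direct import path, kernel name, and launch line are assumptions for illustration, not part of this diff):

    import torch
    import triton
    import triton.language as tl

    # Assumed import path, based on the file location in this diff.
    from helion._triton_ext.gmem_barrier import _triton_wait_multiple_signal

    @triton.jit
    def _wait_on_barriers(signal_pad, stride, N: tl.constexpr):
        # One int32 barrier pointer per lane; N must be a power of two for tl.arange
        # and at most 32 to satisfy the static_assert above.
        indices = tl.arange(0, N)
        _triton_wait_multiple_signal(
            addr=signal_pad + indices * stride,
            expect=1,          # spin until each barrier reads 1
            update=0,          # unused for op='ld', which never writes the barrier
            sem="acquire",
            scope="gpu",
            op="ld",
            skip_sync=False,   # end with bar.sync 0 so the whole block re-converges
        )

    # signal_pad = torch.ones(16, device="cuda", dtype=torch.int32)
    # _wait_on_barriers[(1,)](signal_pad, signal_pad.stride(0), N=16)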

helion/language/signal_wait.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ def _(
9494
if scope not in valid_scopes:
9595
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
9696

97-
# TODO(joydddd): add support for non scalar index into signal_pad
98-
for i in index:
99-
assert isinstance(i, int | torch.SymInt)
100-
10197
index = Tile._prepare_index(index)
10298
index = Tile._tiles_to_sizes(index)
10399

@@ -147,7 +143,17 @@ def _(state: CodegenState) -> ast.AST:
147143
assert type(sem) is str
148144
assert type(scope) is str
149145

150-
hl_ext_call_triton_wait_signal = f"hl_ext._triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
146+
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
147+
is_scalar = len(bar_tensor_shape) == 0
148+
149+
if is_scalar:
150+
hl_ext_call_triton_wait_signal = f"hl_ext._triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
151+
else:
152+
if signal_pad.dtype not in (torch.int32, torch.uint32):
153+
raise NotImplementedError(
154+
f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
155+
)
156+
hl_ext_call_triton_wait_signal = f"hl_ext._triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
151157

152158
return expr_from_string(
153159
hl_ext_call_triton_wait_signal,
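
The codegen branch above keys off SubscriptIndexing.compute_shape: a zero-dimensional result keeps the existing scalar _triton_wait_signal path, while a non-empty shape (e.g. a tile index) emits _triton_wait_multiple_signal and additionally requires an int32/uint32 signal pad. Roughly, the two shapes of generated call look as follows (based on the templates above and the expected strings in the tests below; offset, indices_0, and the stride name are codegen-produced and shown here only for illustration):

    # scalar barrier index -> wait on a single barrier
    hl_ext._triton_wait_signal(addr=signal_pad + offset, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)

    # tile/tensor barrier index -> wait on one barrier per element of the tile
    hl_ext._triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)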

test/test_signal_wait.py

Lines changed: 61 additions & 1 deletion
@@ -101,6 +101,36 @@ def _wait_for_2d_tile_kernel_make_precompiler(signal_pad: torch.Tensor, x: torch
             code,
         )
 
+    def test_wait_multi_bar(self):
+        @helion.kernel
+        def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (N,) = signal_pad.shape
+            n = hl.register_block_size(N)
+            out = torch.empty(n, dtype=torch.int32, device=DEVICE)
+
+            for tile in hl.tile(N, block_size=n):
+                hl.wait(signal_pad, [tile], signal=1)
+                out[tile.id] = tile.id
+
+            return out
+
+        signal_pad = torch.ones(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_wait_multi_bar_kernel, (signal_pad,), block_size=[4]
+        )
+        torch.testing.assert_close(
+            result, torch.arange(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.maxDiff = None
+        self.assertIn(
+            """\
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    hl_ext._triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)""",
+            code,
+        )
+
     def test_signal_basic(self):
         @helion.kernel
         def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -135,7 +165,7 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertIn("hl_ext._triton_send_signal", code)
 
-    def test_sent_recieve_cta(self):
+    def test_send_recieve_cta(self):
         @helion.kernel
         def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
             (n,) = signal_pad.shape
@@ -154,6 +184,36 @@ def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         self.assertIn("hl_ext._triton_send_signal", code)
         self.assertIn("hl_ext._triton_wait_signal", code)
 
+    def test_global_sync(self):
+        @helion.kernel
+        def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            M, N = signal_pad.shape
+            assert M == N
+            for i in hl.grid(N):
+                for tile in hl.tile(N, block_size=N):
+                    hl.signal(signal_pad, [tile, i], signal=1, skip_sync=True)
+                    hl.wait(signal_pad, [i, tile], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, 4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_multi_bar_sync_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, 4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn(
+            """
+@triton.jit
+def _gmem_multi_bar_sync_kernel_kernel(signal_pad, signal_pad_stride_0, signal_pad_stride_1, N, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    for offset_1 in tl.range(0, N.to(tl.int32), step=_BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        hl_ext._triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=True)
+        hl_ext._triton_wait_multiple_signal(addr=signal_pad + (offset_0 * signal_pad_stride_0 + indices_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)""",
+            code,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
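
test_global_sync composes the two primitives into a grid-wide barrier: program i signals column i of the N x N signal pad and then waits on row i, which can only complete once every program has signaled, so the kernel returns an all-ones pad. A sequential toy model of that ordering (illustrative only, not part of this PR):

    import torch

    N = 4
    signal_pad = torch.zeros(N, N, dtype=torch.int32)

    # Stand-in for hl.grid(N): program i marks column i for every row.
    for i in range(N):
        signal_pad[:, i] = 1                        # hl.signal(signal_pad, [tile, i], signal=1)

    # Program i's wait on row i can only succeed after all programs have signaled.
    for i in range(N):
        assert bool((signal_pad[i, :] == 1).all())  # hl.wait(signal_pad, [i, tile], signal=1)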
