Skip to content

Add hl.wait for simultaneous waiting on multiple gmem barriers #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions helion/language/signal_wait.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _(
f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
)

if op == "atomic_cas" and not update:
if op == "atomic_cas" and update is None:
raise ValueError(
f"{op} without an update value. Do you want to use 'ld' instead? "
)
Expand All @@ -88,10 +88,6 @@ def _(
if scope not in valid_scopes:
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

# TODO(joydddd): add support for non scalar index into signal_pad
for i in index:
assert isinstance(i, int | torch.SymInt)

index = Tile._prepare_index(index)
index = Tile._tiles_to_sizes(index)

Expand Down Expand Up @@ -141,7 +137,17 @@ def _(state: CodegenState) -> ast.AST:
assert type(sem) is str
assert type(scope) is str

call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
is_scalar = len(bar_tensor_shape) == 0

if is_scalar:
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
else:
if signal_pad.dtype not in (torch.int32, torch.uint32):
raise NotImplementedError(
f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
)
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"

return expr_from_string(
call_triton_wait_signal,
Expand Down Expand Up @@ -272,7 +278,7 @@ def _(state: CodegenState) -> ast.AST:
if is_scalar:
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
else:
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync), sync_after=True)"
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"

return expr_from_string(
call_triton_wait_signal,
Expand Down
1 change: 1 addition & 0 deletions helion/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel
from .triton_helpers import triton_send_signal as triton_send_signal
from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
from .triton_helpers import triton_wait_signal as triton_wait_signal


Expand Down
102 changes: 98 additions & 4 deletions helion/runtime/triton_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,107 @@ def triton_wait_signal(
@triton.jit
def triton_wait_multiple_signal(
addr: tl.tensor,
expect: tl.constexpr, # wait until lock is set to expect
update: tl.constexpr, # update the lock once it is acquired.
expect: tl.constexpr,
update: tl.constexpr,
sem: tl.constexpr,
scope: tl.constexpr,
op: tl.constexpr,
skip_sync: tl.constexpr,
sync_before: tl.constexpr = False, # pyre-ignore[9]
) -> None:
raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
# TODO(joydddd): waiting on multiple barriers at the same time where each thread waits on a different barrier
"""
Simultaneously wait for multiple global memory barriers to reach the expected value.

Each thread in a CTA whose thread index is below the number of barriers spin-waits on a memory location, repeatedly checking it until it holds the expected value, providing synchronization across CTAs.

Args:
addr: Memory addresses of the barriers to wait on (Maximum 32 barriers)
expect: Expected value to wait for
update: Value to write to the barrier once it is acquired (only used when op is "atomic_cas")
sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed", "release".
scope: Scope of the atomic operation. Options: "gpu", "sys"
op: Atomic operation type: "ld", "atomic_cas"
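skip_sync: Skip CTA synchronization after acquiring the barriers. (No default; must be passed explicitly.)
sync_before: Accepted in the signature (default: False). NOTE(review): the body shown here never reads it - confirm whether a CTA sync before waiting is intended.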
"""
tl.static_assert(
(sem == "acquire" or sem == "relaxed") or sem == "release",
"Invalid memory semantic. options: 'acquire', 'relaxed' 'release'. ",
)
tl.static_assert(
scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
)
tl.static_assert(
op == "ld" or op == "atomic_cas",
"Invalid op. options: 'ld', 'atomic_cas'. ",
)

tl.static_assert(
addr.dtype == tl.pointer_type(tl.int32),
"Invalid barrier value type. Only supports int32 for multi barrier signal. ",
)

addr = tl.ravel(addr)

tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
tl.static_assert(addr.shape[0] <= 32, "Wait on at most 32 barriers at a time. ")

# Assume Triton always sets tid.y == tid.z == 0.
if op == "ld":
tl.inline_asm_elementwise(
f"""
{{
.reg .u32 %tmp32_<3>;
.reg .pred %p<2>;

mov.u32 %tmp32_0, %tid.x;
setp.lt.s32 %p1, %tmp32_0, $2;

mov.u32 $0, 0;
// initialize tmp_0 to 0
wait_block:
@%p1 ld.global.{sem}.{scope}.u32 $0, [$1];
setp.ne.u32 %p0, $0, $3;
and.pred %p0, %p0, %p1;
@%p0 bra wait_block;
}}
""",
"=r, l, r, r",
[addr, addr.shape[0], expect],
dtype=addr.dtype.element_ty,
is_pure=False,
pack=1,
)
elif op == "atomic_cas":
tl.inline_asm_elementwise(
f"""
{{
.reg .u32 %tmp32_<3>;
.reg .pred %p<2>;

mov.u32 %tmp32_0, %tid.x;
setp.lt.s32 %p1, %tmp32_0, $2;

mov.u32 $0, 0;
// initialize tmp_0 to 0
wait_block:
@%p1 atom.global.{sem}.{scope}.cas.b32 $0, [$1], $3, $4;
setp.ne.u32 %p0, $0, $3;
and.pred %p0, %p0, %p1;
@%p0 bra wait_block;
}}
""",
"=r, l, r, r, r",
[addr, addr.shape[0], expect, update],
dtype=addr.dtype.element_ty,
is_pure=False,
pack=1,
)
else:
raise NotImplementedError(
f"Unsupported op '{op}' for wait signal on gmem barrier. "
)

if not skip_sync:
tl.inline_asm_elementwise(
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
)
121 changes: 121 additions & 0 deletions test/test_signal_wait.expected
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestWait.test_global_sync)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _gmem_multi_bar_sync_kernel_kernel(signal_pad, signal_pad_stride_0, signal_pad_stride_1, N, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
for offset_1 in tl.range(0, N.to(tl.int32), step=_BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
helion.runtime.triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=True)
helion.runtime.triton_wait_multiple_signal(addr=signal_pad + (offset_0 * signal_pad_stride_0 + indices_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)

def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor):
M, N = signal_pad.shape
assert M == N
_BLOCK_SIZE_1 = N
_gmem_multi_bar_sync_kernel_kernel[N,](signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return signal_pad

def _gmem_multi_bar_sync_kernel_make_precompiler(signal_pad: torch.Tensor):
M, N = signal_pad.shape
assert M == N
_BLOCK_SIZE_1 = N
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_multi_bar_sync_kernel_kernel)(signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_signal_basic)
from __future__ import annotations

Expand Down Expand Up @@ -76,6 +107,33 @@ def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_signal_multiple_cas)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=0, update=1, sem='release', scope='gpu', op='atomic_cas', skip_sync=True, sync_before=not False)

def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor):
n, = signal_pad.shape
_BLOCK_SIZE_0 = 4
_gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return signal_pad

def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
n, = signal_pad.shape
_BLOCK_SIZE_0 = 4
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_wait_2d_tile)
from __future__ import annotations

Expand Down Expand Up @@ -144,3 +202,66 @@ def _gmem_wait_kernel_make_precompiler(signal_pad: torch.Tensor):
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_wait_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_wait_multi_bar)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

import test.test_signal_wait as _source_module

@triton.jit
def _gmem_wait_multi_bar_kernel_kernel(signal_pad, out, out_stride_0, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
tile_id = offset_0 // _BLOCK_SIZE_0
tl.store(out + tile_id * out_stride_0, tile_id, None)

def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor):
N, = signal_pad.shape
n = 4
out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
_BLOCK_SIZE_0 = 4
_gmem_wait_multi_bar_kernel_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def _gmem_wait_multi_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
N, = signal_pad.shape
n = 4
out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
_BLOCK_SIZE_0 = 4
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_wait_multi_bar_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_wait_multi_bar_cas)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _gmem_wait_multi_bar_kernel_cas_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=2, sem='acquire', scope='gpu', op='atomic_cas', skip_sync=False)

def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor):
N, = signal_pad.shape
n = 4
_BLOCK_SIZE_0 = 4
_gmem_wait_multi_bar_kernel_cas_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return signal_pad

def _gmem_wait_multi_bar_kernel_cas_make_precompiler(signal_pad: torch.Tensor):
N, = signal_pad.shape
n = 4
_BLOCK_SIZE_0 = 4
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_wait_multi_bar_kernel_cas_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
Loading
Loading