
Add hl.signal #233

Draft · wants to merge 1 commit into base: joydddd/stack/5

3 changes: 2 additions & 1 deletion helion/_triton_ext/__init__.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from .gmem_barrier import _triton_send_signal
from .gmem_barrier import _triton_wait_multiple_signal
from .gmem_barrier import _triton_wait_signal

-__all__ = ["_triton_wait_multiple_signal", "_triton_wait_signal"]
+__all__ = ["_triton_send_signal", "_triton_wait_multiple_signal", "_triton_wait_signal"]
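
With this change the send helper is exported from helion._triton_ext alongside the existing wait helpers (generated device code refers to them through the hl_ext alias, as the codegen and tests below show). The resulting public surface:

from helion._triton_ext import (
    _triton_send_signal,  # new in this PR
    _triton_wait_multiple_signal,
    _triton_wait_signal,
)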
45 changes: 45 additions & 0 deletions helion/_triton_ext/gmem_barrier.py
@@ -4,6 +4,51 @@
import triton
import triton.language as tl

__all__ = ["_triton_send_signal", "_triton_wait_multiple_signal", "_triton_wait_signal"]


@triton.jit
def _triton_send_signal(
    addr,  # can be a scalar or a vector of pointers.
    update: tl.constexpr,
    sem: tl.constexpr,
    scope: tl.constexpr,
    op: tl.constexpr,
    skip_sync: tl.constexpr,
) -> None:
    """
    Send a signal to a global memory barrier.

    Atomically updates the barrier location(s) at ``addr`` so that CTAs
    spinning in a wait loop can observe the new value.

    Args:
        addr: Memory address of the barrier(s) to signal (scalar or vector of pointers)
        update: Value to write (atomic_xchg) or add (atomic_add) to the barrier
        sem: Memory semantics of the atomic ("release" or "relaxed")
        scope: Scope of the atomic ("gpu" or "sys")
        op: Atomic operation to perform ("atomic_xchg" or "atomic_add")
        skip_sync: If True, skip the CTA-wide bar.sync before signaling
    """
    if not skip_sync:
        tl.inline_asm_elementwise(
            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
        )

    tl.static_assert(
        sem == "release" or sem == "relaxed",
        "Invalid memory semantic. options: 'release', 'relaxed'. ",
    )
    tl.static_assert(
        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu','sys'. "
    )

    if op == "atomic_xchg":
        tl.atomic_xchg(addr, update, sem=sem, scope=scope)
    elif op == "atomic_add":
        tl.atomic_add(addr, update, sem=sem, scope=scope)
    else:
        raise NotImplementedError(
            f"Unsupported op '{op}' for send signal on gmem barrier. "
        )


@triton.jit
def _triton_wait_signal(
    ...
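
For reviewers who want to exercise the new helper in isolation, here is a minimal sketch that drives _triton_send_signal from a standalone Triton kernel. The kernel name, launch grid, and tensor below are illustrative only; just the helper's signature is taken from the diff above.

import torch
import triton
import triton.language as tl

from helion._triton_ext import _triton_send_signal


@triton.jit
def _demo_signal_kernel(signal_pad_ptr):
    # Each program atomically swaps its own barrier slot to 1; "release"
    # ordering makes the program's prior writes visible GPU-wide first.
    pid = tl.program_id(0)
    _triton_send_signal(
        signal_pad_ptr + pid,
        update=1,
        sem="release",
        scope="gpu",
        op="atomic_xchg",
        skip_sync=False,  # emit the CTA-wide bar.sync before signaling
    )


signal_pad = torch.zeros(4, device="cuda", dtype=torch.int32)
_demo_signal_kernel[(4,)](signal_pad)
assert torch.equal(signal_pad, torch.ones_like(signal_pad))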
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -11,6 +11,7 @@
from .memory_ops import atomic_add as atomic_add
from .memory_ops import load as load
from .memory_ops import store as store
from .signal_wait import signal as signal
from .signal_wait import wait as wait
from .tile_ops import tile_begin as tile_begin
from .tile_ops import tile_block_size as tile_block_size
101 changes: 101 additions & 0 deletions helion/language/signal_wait.py
@@ -6,12 +6,16 @@
from torch.fx import has_side_effect

from .. import exc
from .._compiler.indexing_strategy import SubscriptIndexing
from . import _decorators

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState
    from helion._compiler.type_propagation import SymIntType

__all__ = ["signal", "wait"]


@has_side_effect
@@ -195,3 +199,100 @@ def _(state: CodegenState) -> ast.AST:
        signal=signal_expr,
        update=update_expr,
    )


@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def signal(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> torch.Tensor | SymIntType:
    raise exc.NotInsideKernel


@_decorators.prepare_args(signal)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, str, str, str, bool]:
    from helion.language.tile_proxy import Tile

    valid_ops = {"atomic_add", "atomic_xchg"}
    valid_sems = {"relaxed", "release", "acq_rel"}
    valid_scopes = {"sys", "gpu"}

    if op not in valid_ops:
        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}. ")

    if sem not in valid_sems:
        raise ValueError(
            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
        )

    if scope not in valid_scopes:
        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

    index = Tile._prepare_index(index)
    index = Tile._tiles_to_sizes(index)

    return (signal_pad, index, signal, op, sem, scope, skip_sync)


@_decorators.register_fake(signal)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    op: str = "atomic_xchg",
    sem: str = "release",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> torch.Tensor:
    return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))


@_decorators.codegen(signal)
def _(state: CodegenState) -> ast.AST:
    import ast

    from .._compiler.ast_extension import expr_from_string
    from .._compiler.indexing_strategy import SubscriptIndexing

    signal_pad = state.proxy_arg(0)
    index = state.proxy_arg(1)
    signal = state.proxy_arg(2)
    op = state.proxy_arg(3)
    sem = state.proxy_arg(4)
    scope = state.proxy_arg(5)
    skip_sync = state.proxy_arg(6)

    assert isinstance(signal_pad, torch.Tensor)
    assert isinstance(index, list)

    indices = SubscriptIndexing.create(state, signal_pad, index)
    signal_pad_name = state.device_function.tensor_arg(signal_pad).name

    signal_expr = ast.Constant(value=signal)
    assert type(op) is str
    assert type(sem) is str
    assert type(scope) is str

    hl_ext_call_triton_send_signal = f"hl_ext._triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"

    # lock_spin_ptx = get_lock_spin_ptx(signal_pad_name, op, sem, scope)

    return expr_from_string(
        hl_ext_call_triton_send_signal,
        offset=indices.index_expr,
        signal=signal_expr,
    )
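
Concretely, with the default arguments (hl.signal(signal_pad, [i])), the f-string above expands into a device call of roughly this shape; signal_pad and offset_0 stand in for the generated tensor-argument and index names:

# Illustrative expansion only; real identifiers come from the device function.
hl_ext._triton_send_signal(
    addr=signal_pad + offset_0,  # tensor base plus SubscriptIndexing offset
    update=1,                    # the `signal` value, inlined via ast.Constant
    sem='release',
    scope='gpu',
    op='atomic_xchg',
    skip_sync=False,
)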
53 changes: 53 additions & 0 deletions test/test_signal_wait.py
@@ -101,6 +101,59 @@ def _wait_for_2d_tile_kernel_make_precompiler(signal_pad: torch.Tensor, x: torch
            code,
        )

    def test_signal_basic(self):
        @helion.kernel
        def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
            (n,) = signal_pad.shape
            for i in hl.grid(n):
                hl.signal(signal_pad, [i], signal=1)
            return signal_pad

        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
        code, result = code_and_output(gmem_signal_scalar_bar_kernel, (signal_pad,))
        torch.testing.assert_close(
            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
        )
        self.assertIn("hl_ext._triton_send_signal", code)

    def test_signal_multiple(self):
        @helion.kernel
        def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
            (n,) = signal_pad.shape
            for tile in hl.tile(n):
                hl.signal(signal_pad, [tile], signal=1)
            return signal_pad

        signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32)
        code, result = code_and_output(
            gmem_signal_tensor_bar_kernel,
            (signal_pad,),
            block_size=[4],
        )
        torch.testing.assert_close(
            result, torch.ones(16, device=DEVICE, dtype=torch.int32)
        )
        self.assertIn("hl_ext._triton_send_signal", code)

    def test_send_receive_cta(self):
        @helion.kernel
        def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
            (n,) = signal_pad.shape
            for i in hl.grid(n):  # first N CTAs send the signal
                hl.signal(signal_pad, [i], signal=1)
            for i in hl.grid(n):  # last N CTAs wait for the signal
                hl.wait(signal_pad, [i], signal=1)
            return signal_pad

        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)

        code, result = code_and_output(gmem_signal_n_wait_kernel, (signal_pad,))
        torch.testing.assert_close(
            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
        )
        self.assertIn("hl_ext._triton_send_signal", code)
        self.assertIn("hl_ext._triton_wait_signal", code)

if __name__ == "__main__":
    unittest.main()