From 3844422247f09ddd5d8ac65e9dcaf5ce53ab62dd Mon Sep 17 00:00:00 2001
From: joydddd
Date: Tue, 1 Jul 2025 16:00:30 -0700
Subject: [PATCH] Add hl.signal

stack-info: PR: https://github.com/pytorch-labs/helion/pull/233, branch: joydddd/stack/8
---
 helion/_triton_ext/__init__.py     |  3 +-
 helion/_triton_ext/gmem_barrier.py | 45 ++++++++++++++
 helion/language/__init__.py        |  1 +
 helion/language/signal_wait.py     | 99 ++++++++++++++++++++++++++++++
 test/test_signal_wait.py           | 53 ++++++++++++++++
 5 files changed, 200 insertions(+), 1 deletion(-)

diff --git a/helion/_triton_ext/__init__.py b/helion/_triton_ext/__init__.py
index 434125a0..86ac9ed2 100644
--- a/helion/_triton_ext/__init__.py
+++ b/helion/_triton_ext/__init__.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
+from .gmem_barrier import _triton_send_signal
 from .gmem_barrier import _triton_wait_multiple_signal
 from .gmem_barrier import _triton_wait_signal
 
-__all__ = ["_triton_wait_multiple_signal", "_triton_wait_signal"]
+__all__ = ["_triton_send_signal", "_triton_wait_multiple_signal", "_triton_wait_signal"]
diff --git a/helion/_triton_ext/gmem_barrier.py b/helion/_triton_ext/gmem_barrier.py
index a49fe3f5..8f0e9931 100644
--- a/helion/_triton_ext/gmem_barrier.py
+++ b/helion/_triton_ext/gmem_barrier.py
@@ -4,6 +4,51 @@
 import triton
 import triton.language as tl
 
+__all__ = ["_triton_send_signal", "_triton_wait_multiple_signal", "_triton_wait_signal"]
+
+
+@triton.jit
+def _triton_send_signal(
+    addr,  # can be a scalar or a vector of pointers.
+    update: tl.constexpr,
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    """
+    Send a signal to a global memory barrier.
+
+    This function atomically updates a barrier in global memory (atomic_xchg or
+    atomic_add), releasing any threads that spin-wait on that location.
+
+    Args:
+        addr: Barrier address(es) to update (scalar or vector of pointers)
+        update: Value to write (op='atomic_xchg') or add (op='atomic_add')
+        skip_sync: If True, skip the CTA-wide bar.sync before signaling
+    """
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
+    tl.static_assert(
+        sem == "release" or sem == "relaxed",
+        "Invalid memory semantic. options: 'release', 'relaxed'.",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'."
+    )
+
+    if op == "atomic_xchg":
+        tl.atomic_xchg(addr, update, sem=sem, scope=scope)
+    elif op == "atomic_add":
+        tl.atomic_add(addr, update, sem=sem, scope=scope)
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for send signal on gmem barrier."
+        )
+
 
 @triton.jit
 def _triton_wait_signal(
diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index 277f1cc5..e38d6ad1 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -15,6 +15,7 @@
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
 from .scan_ops import cumsum as cumsum
+from .signal_wait import signal as signal
 from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py
index bd47c6d0..39cf6121 100644
--- a/helion/language/signal_wait.py
+++ b/helion/language/signal_wait.py
@@ -6,12 +6,16 @@
 from torch.fx import has_side_effect
 
 from .. import exc
+from .._compiler.indexing_strategy import SubscriptIndexing
 from . import _decorators
 
 if TYPE_CHECKING:
     import ast
 
     from .._compiler.inductor_lowering import CodegenState
+    from helion._compiler.type_propagation import SymIntType
+
+__all__ = ["signal", "wait"]
 
 
 @has_side_effect
@@ -151,3 +155,98 @@ def _(state: CodegenState) -> ast.AST:
         signal=signal_expr,
         update=update_expr,
     )
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def signal(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor | SymIntType:
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> tuple[torch.Tensor, object, int, str, str, str, bool]:
+    from helion.language.tile_proxy import Tile
+
+    valid_ops = {"atomic_add", "atomic_xchg"}
+    valid_sems = {"relaxed", "release", "acq_rel"}
+    valid_scopes = {"sys", "gpu"}
+
+    if op not in valid_ops:
+        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}.")
+
+    if sem not in valid_sems:
+        raise ValueError(
+            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
+        )
+
+    if scope not in valid_scopes:
+        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
+
+    index = Tile._prepare_index(index)
+    index = Tile._tiles_to_sizes(index)
+
+    return (signal_pad, index, signal, op, sem, scope, skip_sync)
+
+
+@_decorators.register_fake(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))
+
+
+@_decorators.codegen(signal)
+def _(state: CodegenState) -> ast.AST:
+    import ast
+
+    from .._compiler.ast_extension import expr_from_string
+    from .._compiler.indexing_strategy import SubscriptIndexing
+
+    signal_pad = state.proxy_arg(0)
+    index = state.proxy_arg(1)
+    signal = state.proxy_arg(2)
+    op = state.proxy_arg(3)
+    sem = state.proxy_arg(4)
+    scope = state.proxy_arg(5)
+    skip_sync = state.proxy_arg(6)
+
+    assert isinstance(signal_pad, torch.Tensor)
+    assert isinstance(index, list)
+
+    indices = SubscriptIndexing.create(state, signal_pad, index)
+    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+
+    signal_expr = ast.Constant(value=signal)
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    hl_ext_call_triton_send_signal = f"hl_ext._triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+
+    return expr_from_string(
+        hl_ext_call_triton_send_signal,
+        offset=indices.index_expr,
+        signal=signal_expr,
+    )
diff --git a/test/test_signal_wait.py b/test/test_signal_wait.py
index 94e703b2..ae3ebf76 100644
--- a/test/test_signal_wait.py
+++ b/test/test_signal_wait.py
@@ -101,6 +101,59 @@ def _wait_for_2d_tile_kernel_make_precompiler(signal_pad: torch.Tensor, x: torch
             code,
         )
 
+    def test_signal_basic(self):
+        @helion.kernel
+        def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.signal(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_signal_scalar_bar_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("hl_ext._triton_send_signal", code)
+
+    def test_signal_multiple(self):
+        @helion.kernel
+        def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for tile in hl.tile(n):
+                hl.signal(signal_pad, [tile], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_signal_tensor_bar_kernel,
+            (signal_pad,),
+            block_size=[4],
+        )
+        torch.testing.assert_close(
+            result, torch.ones(16, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("hl_ext._triton_send_signal", code)
+
+    def test_send_receive_cta(self):
+        @helion.kernel
+        def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):  # first N CTAs send the signal
+                hl.signal(signal_pad, [i], signal=1)
+            for i in hl.grid(n):  # last N CTAs wait for the signal
+                hl.wait(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_signal_n_wait_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("hl_ext._triton_send_signal", code)
+        self.assertIn("hl_ext._triton_wait_signal", code)
+
 
 if __name__ == "__main__":
     unittest.main()
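
Usage sketch (reviewer note, not part of the patch): a minimal end-to-end
example pairing the new hl.signal with the existing hl.wait, modeled directly
on test_send_receive_cta above. The kernel name and the cuda/int32 choices
are illustrative assumptions, not something this PR prescribes.

    import torch
    import helion
    import helion.language as hl

    @helion.kernel
    def signal_then_wait(signal_pad: torch.Tensor) -> torch.Tensor:
        (n,) = signal_pad.shape
        for i in hl.grid(n):  # each program sets its barrier slot...
            hl.signal(signal_pad, [i], signal=1)
        for i in hl.grid(n):  # ...then spin-waits until the slot reads 1
            hl.wait(signal_pad, [i], signal=1)
        return signal_pad

    # Hypothetical invocation: the barrier pad is a zero-initialized int32
    # tensor in global memory, one flag per program instance.
    signal_pad = torch.zeros(4, device="cuda", dtype=torch.int32)
    out = signal_then_wait(signal_pad)  # expected: a tensor of ones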