diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index 277f1cc5..e38d6ad1 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -15,6 +15,7 @@
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
 from .scan_ops import cumsum as cumsum
+from .signal_wait import signal as signal
 from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py
index 7d8b81cd..0d501b56 100644
--- a/helion/language/signal_wait.py
+++ b/helion/language/signal_wait.py
@@ -6,6 +6,7 @@
 from torch.fx import has_side_effect
 
 from .. import exc
+from .._compiler.indexing_strategy import SubscriptIndexing
 from . import _decorators
 
 if TYPE_CHECKING:
@@ -13,6 +14,8 @@
 
     from .._compiler.inductor_lowering import CodegenState
 
+__all__ = ["signal", "wait"]
+
 
 @has_side_effect
 @_decorators.api(tiles_as_sizes=True)
@@ -146,3 +149,143 @@ def _(state: CodegenState) -> ast.AST:
         signal=signal_expr,
         update=update_expr,
     )
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def signal(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    wait_for: int | None = None,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> torch.Tensor:
+    """Set the signal_pad slice to the signal value.
+
+    Args:
+        signal_pad: The signal pad to signal
+        index: Indices to index into the signal_pad tensor
+        signal: The value to send
+        wait_for: The value to wait for before sending the signal. Only valid for op = 'atomic_cas'.
+        op: The atomic op used to set the barrier (default: 'atomic_xchg')
+        sem: The memory semantics of the atomic op (default: 'release')
+        scope: The scope of the atomic op (default: 'gpu')
+        skip_sync: Skip the syncthreads before sending the signal (default: False)
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(signal)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    wait_for: int | None = None,
+    op: str = "atomic_xchg",
+    sem: str = "release",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    from helion.language.tile_proxy import Tile
+
+    valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
+    valid_sems = {"relaxed", "release", "acq_rel"}
+    valid_scopes = {"sys", "gpu"}
+
+    if op not in valid_ops:
+        raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}. ")
+
+    if op == "atomic_cas" and wait_for is None:
+        raise ValueError(
+            f"{op} without a wait_for value. Do you want to use 'atomic_add' or 'atomic_xchg' instead? "
+        )
+    if op in {"atomic_add", "atomic_xchg"} and wait_for is not None:
+        raise ValueError(
+            f"{op} with a wait_for value. Do you want to use 'atomic_cas' instead? "
+        )
+
+    if sem not in valid_sems:
+        raise ValueError(
+            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
+        )
+
+    if scope not in valid_scopes:
+        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
Must be one of {valid_scopes}.") + + index = Tile._prepare_index(index) + index = Tile._tiles_to_sizes(index) + + return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync) + + +@_decorators.register_fake(signal) +def _( + signal_pad: torch.Tensor, + index: list[object], + signal: int = 1, + wait_for: int | None = None, + op: str = "atomic_xchg", + sem: str = "release", + scope: str = "gpu", + skip_sync: bool = False, +) -> torch.Tensor: + return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index)) + + +@_decorators.codegen(signal) +def _(state: CodegenState) -> ast.AST: + import ast + + from .._compiler.ast_extension import expr_from_string + from .._compiler.indexing_strategy import SubscriptIndexing + + signal_pad = state.proxy_arg(0) + index = state.proxy_arg(1) + signal = state.proxy_arg(2) + wait_for = state.proxy_arg(3) + op = state.proxy_arg(4) + sem = state.proxy_arg(5) + scope = state.proxy_arg(6) + skip_sync = state.proxy_arg(7) + + assert isinstance(signal_pad, torch.Tensor) + assert isinstance(index, list) + + indices = SubscriptIndexing.create(state, signal_pad, index) + signal_pad_name = state.device_function.tensor_arg(signal_pad).name + + signal_expr = ast.Constant(value=signal) + if wait_for is not None: + wait_for_expr = ast.Constant(value=wait_for) + else: + wait_for_expr = ast.Constant(value=0) + skip_sync_expr = ast.Constant(value=skip_sync) + assert type(op) is str + assert type(sem) is str + assert type(scope) is str + + if op == "atomic_cas": + bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index) + is_scalar = len(bar_tensor_shape) == 0 + if is_scalar: + call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))" + else: + call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync), sync_after=True)" + + return expr_from_string( + call_triton_wait_signal, + offset=indices.index_expr, + wait_for=wait_for_expr, + signal=signal_expr, + skip_sync=skip_sync_expr, + ) + call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)" + + return expr_from_string( + call_triton_send_signal, + offset=indices.index_expr, + signal=signal_expr, + skip_sync=skip_sync_expr, + ) diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 4b7259fb..0bd3d2b8 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -8,6 +8,7 @@ from .config import Config as Config from .kernel import Kernel as Kernel from .kernel import kernel as kernel +from .triton_helpers import triton_send_signal as triton_send_signal from .triton_helpers import triton_wait_signal as triton_wait_signal diff --git a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py index a96e2d41..0f1fc19a 100644 --- a/helion/runtime/triton_helpers.py +++ b/helion/runtime/triton_helpers.py @@ -3,7 +3,56 @@ import triton import triton.language as tl -__all__ = ["triton_wait_signal"] +__all__ = ["triton_send_signal", "triton_wait_multiple_signal", "triton_wait_signal"] + + +@triton.jit +def triton_send_signal( + addr: tl.tensor, + update: tl.constexpr, + sem: tl.constexpr, + scope: tl.constexpr, + op: tl.constexpr, + 
+    skip_sync: tl.constexpr,
+) -> tl.tensor:
+    """
+    Signal global memory barrier(s).
+
+    This function atomically sets global memory barriers to an update value,
+    signaling to other CTAs waiting on the barrier(s).
+
+    Args:
+        addr: Memory address of the barrier(s) to signal
+        update: Value to set the barrier(s) to
+        sem: Memory semantics for the atomic operation. Options: "release", "relaxed".
+        scope: Scope of the atomic operation. Options: "gpu", "sys"
+        op: Atomic operation type: "atomic_xchg", "atomic_add"
+        skip_sync: Skip CTA synchronization before setting the barrier (default: False)
+    Returns:
+        The old value of the barrier(s) before the update.
+    """
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
+    tl.static_assert(
+        sem == "release" or sem == "relaxed",
+        "Invalid memory semantic. options: 'release', 'relaxed'. ",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
+    )
+
+    if op == "atomic_xchg":
+        barrier_status = tl.atomic_xchg(addr, update, sem=sem, scope=scope)
+    elif op == "atomic_add":
+        barrier_status = tl.atomic_add(addr, update, sem=sem, scope=scope)
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for sending a signal on a gmem barrier. "
+        )
+    return barrier_status
@@ -15,6 +64,7 @@ def triton_wait_signal(
     expect: tl.constexpr,
     sem: tl.constexpr,
     scope: tl.constexpr,
     op: tl.constexpr,
     skip_sync: tl.constexpr,
+    sync_before: tl.constexpr = False,  # pyre-ignore[9]
 ) -> None:
     """
     Wait for a global memory barrier to reach the expected value.
@@ -30,6 +80,7 @@ def triton_wait_signal(
         scope: Scope of the atomic operation. Options: "gpu", "sys"
         op: Atomic operation type: "ld", "atomic_cas"
         skip_sync: Skip CTA sync after acquiring the barrier (default: False)
+        sync_before: Add a CTA sync before the wait (default: False)
     """
     tl.static_assert(
         addr.type.is_ptr(),
@@ -37,8 +88,8 @@ def triton_wait_signal(
 
     tl.static_assert(
-        sem == "acquire" or sem == "relaxed",
-        "Invalid memory semantic. options: 'acquire', 'relaxed'. ",
+        (sem == "acquire" or sem == "relaxed") or sem == "release",
+        "Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
     )
     tl.static_assert(
         scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
     )
@@ -48,6 +99,11 @@ def triton_wait_signal(
         "Invalid op. options: 'ld', 'atomic_cas'. ",
     )
 
+    if sync_before:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
     # Spin-wait loop:
     # Uses atomic_add with update=0 for ld.global.{sem}.{scope}
     # Triton generates smem broadcasting of tl.atomic_add return value in ptx,
@@ -71,3 +127,18 @@ def triton_wait_signal(
             "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
         )
     # tl.debug_barrier() cause significant performance loss. (Perhaps breaks triton prefetching?)
+
+
+@triton.jit
+def triton_wait_multiple_signal(
+    addr: tl.tensor,
+    expect: tl.constexpr,  # wait until lock is set to expect
+    update: tl.constexpr,  # update the lock once it is acquired.
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+    sync_before: tl.constexpr = False,  # pyre-ignore[9]
+) -> None:
+    raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
") + # TODO(joydddd): waiting on multiple barriers at the same time whereeach thread waits on a different barrier diff --git a/test/test_signal_wait.expected b/test/test_signal_wait.expected index f9f30a22..473ef2af 100644 --- a/test/test_signal_wait.expected +++ b/test/test_signal_wait.expected @@ -1,6 +1,81 @@ This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py. Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set. +--- assertExpectedJournal(TestWait.test_signal_basic) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl + +@triton.jit +def _gmem_signal_scalar_bar_kernel_kernel(signal_pad, signal_pad_stride_0): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + helion.runtime.triton_send_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False) + +def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor): + n, = signal_pad.shape + _gmem_signal_scalar_bar_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3) + return signal_pad + +def _gmem_signal_scalar_bar_kernel_make_precompiler(signal_pad: torch.Tensor): + n, = signal_pad.shape + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_gmem_signal_scalar_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3) + +--- assertExpectedJournal(TestWait.test_signal_cas) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl + +@triton.jit +def _gmem_signal_cas_kernel_kernel(signal_pad, signal_pad_stride_0): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + helion.runtime.triton_wait_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, expect=0, update=1, sem='release', scope='gpu', op='atomic_cas', skip_sync=True, sync_before=not False) + +def gmem_signal_cas_kernel(signal_pad: torch.Tensor): + n, = signal_pad.shape + _gmem_signal_cas_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3) + return signal_pad + +def _gmem_signal_cas_kernel_make_precompiler(signal_pad: torch.Tensor): + n, = signal_pad.shape + from helion.runtime.precompile_shim import make_precompiler + return make_precompiler(_gmem_signal_cas_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3) + +--- assertExpectedJournal(TestWait.test_signal_multiple) +from __future__ import annotations + +import torch +import helion +import triton +import triton.language as tl + +@triton.jit +def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + helion.runtime.triton_send_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False) + +def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor): + n, = signal_pad.shape + _BLOCK_SIZE_0 = 4 + _gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return signal_pad + +def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor): + n, = signal_pad.shape + _BLOCK_SIZE_0 = 4 + from helion.runtime.precompile_shim import make_precompiler + return 
+
 --- assertExpectedJournal(TestWait.test_wait_2d_tile)
 from __future__ import annotations
 
diff --git a/test/test_signal_wait.py b/test/test_signal_wait.py
index fbb5ed36..720e1530 100644
--- a/test/test_signal_wait.py
+++ b/test/test_signal_wait.py
@@ -54,6 +54,74 @@ def wait_for_2d_tile_kernel(
         torch.testing.assert_close(result, x)
         self.assertExpectedJournal(code)
 
+    def test_signal_basic(self):
+        @helion.kernel
+        def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.signal(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_signal_scalar_bar_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_signal_cas(self):
+        @helion.kernel
+        def gmem_signal_cas_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.signal(signal_pad, [i], signal=1, wait_for=0, op="atomic_cas")
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_signal_cas_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_signal_multiple(self):
+        @helion.kernel
+        def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for tile in hl.tile(n):
+                hl.signal(signal_pad, [tile], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_signal_tensor_bar_kernel,
+            (signal_pad,),
+            block_size=[4],
+        )
+        torch.testing.assert_close(
+            result, torch.ones(16, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_send_receive_cta(self):
+        @helion.kernel
+        def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):  # first n CTAs send the signal
+                hl.signal(signal_pad, [i], signal=1)
+            for i in hl.grid(n):  # last n CTAs wait for the signal
+                hl.wait(signal_pad, [i], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_signal_n_wait_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("helion.runtime.triton_send_signal", code)
+        self.assertIn("helion.runtime.triton_wait_signal", code)
+
 
 if __name__ == "__main__":
     unittest.main()
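
The new hl.signal composes with the existing hl.wait to implement a one-way handoff through a global-memory signal pad, as exercised by test_send_receive_cta above. A minimal sketch of that pattern outside the test harness; the kernel name, device string, and pad size here are illustrative assumptions, not part of this change:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel
    def one_way_handoff(signal_pad: torch.Tensor) -> torch.Tensor:
        (n,) = signal_pad.shape
        for i in hl.grid(n):  # producer CTAs: set each barrier slot to 1
            hl.signal(signal_pad, [i], signal=1)
        for i in hl.grid(n):  # consumer CTAs: spin until their slot reads 1
            hl.wait(signal_pad, [i], signal=1)
        return signal_pad

    # One int32 slot per CTA, zero-initialized (unsignaled).
    pad = torch.zeros(4, device="cuda", dtype=torch.int32)
    out = one_way_handoff(pad)  # every slot reads 1 once the kernel returns

The signal loop lowers to helion.runtime.triton_send_signal (op='atomic_xchg', sem='release' by default) and the wait loop to helion.runtime.triton_wait_signal, which is what the assertIn checks in the test verify.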