diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py
index 0d501b56..8e6810d5 100644
--- a/helion/language/signal_wait.py
+++ b/helion/language/signal_wait.py
@@ -76,7 +76,7 @@ def _(
             f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
         )
 
-    if op == "atomic_cas" and not update:
+    if op == "atomic_cas" and update is None:
         raise ValueError(
             f"{op} without an update value. Do you want to use 'ld' instead? "
         )
@@ -88,10 +88,6 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
 
-    # TODO(joydddd): add support for non scalar index into signal_pad
-    for i in index:
-        assert isinstance(i, int | torch.SymInt)
-
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)
 
@@ -141,7 +137,17 @@ def _(state: CodegenState) -> ast.AST:
     assert type(sem) is str
     assert type(scope) is str
 
-    call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+    bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+    is_scalar = len(bar_tensor_shape) == 0
+
+    if is_scalar:
+        call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+    else:
+        if signal_pad.dtype not in (torch.int32, torch.uint32):
+            raise NotImplementedError(
+                f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be torch.int32 or torch.uint32."
+            )
+        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
 
     return expr_from_string(
         call_triton_wait_signal,
@@ -272,7 +278,7 @@ def _(state: CodegenState) -> ast.AST:
     if is_scalar:
         call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
     else:
-        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync), sync_after=True)"
+        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
 
     return expr_from_string(
         call_triton_wait_signal,
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py
index 0bd3d2b8..cb6b1aaf 100644
--- a/helion/runtime/__init__.py
+++ b/helion/runtime/__init__.py
@@ -9,6 +9,7 @@
 from .kernel import Kernel as Kernel
 from .kernel import kernel as kernel
 from .triton_helpers import triton_send_signal as triton_send_signal
+from .triton_helpers import triton_wait_multiple_signal as triton_wait_multiple_signal
 from .triton_helpers import triton_wait_signal as triton_wait_signal
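The codegen change above dispatches on the shape of the indexed barrier region: a zero-dim index keeps the existing scalar triton_wait_signal path, while any non-scalar index (on an int32/uint32 pad only) lowers to the new triton_wait_multiple_signal. As a rough illustration of what triggers the new path, adapted from test_wait_multi_bar later in this diff (the kernel name and DEVICE binding are illustrative, not part of the change):

import torch
import helion
import helion.language as hl

DEVICE = torch.device("cuda")  # assumption: any CUDA device

@helion.kernel
def wait_then_mark(signal_pad: torch.Tensor) -> torch.Tensor:
    (N,) = signal_pad.shape
    n = hl.register_block_size(N)
    out = torch.empty(n, dtype=torch.int32, device=DEVICE)
    for tile in hl.tile(N, block_size=n):
        # [tile] indexes a block of barriers, so the indexed shape is
        # non-scalar and the triton_wait_multiple_signal path is generated.
        hl.wait(signal_pad, [tile], signal=1)
        out[tile.id] = tile.id
    return out

# The pad must be int32/uint32 and already hold the expected value:
# wait_then_mark(torch.ones(16, device=DEVICE, dtype=torch.int32))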
diff --git a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py
index 0f1fc19a..b65005c4 100644
--- a/helion/runtime/triton_helpers.py
+++ b/helion/runtime/triton_helpers.py
@@ -132,13 +132,108 @@ def triton_wait_signal(
 @triton.jit
 def triton_wait_multiple_signal(
     addr: tl.tensor,
-    expect: tl.constexpr,  # wait until lock is set to expect
-    update: tl.constexpr,  # update the lock once it is aquired.
+    expect: tl.constexpr,
+    update: tl.constexpr,
     sem: tl.constexpr,
     scope: tl.constexpr,
     op: tl.constexpr,
     skip_sync: tl.constexpr,
     sync_before: tl.constexpr = False,  # pyre-ignore[9]
 ) -> None:
-    raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
-    # TODO(joydddd): waiting on multiple barriers at the same time whereeach thread waits on a different barrier
+    """
+    Simultaneously wait for multiple global memory barriers to reach the expected value.
+
+    Each thread in the CTA spin-waits on its own barrier, repeatedly checking the memory location until it holds the expected value, providing synchronization across CTAs.
+
+    Args:
+        addr: Memory addresses of the barriers to wait on (maximum 32 barriers)
+        expect: Expected value to wait for
+        update: Value to update the barrier with once it is acquired
+        sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed", "release".
+        scope: Scope of the atomic operation. Options: "gpu", "sys"
+        op: Atomic operation type: "ld", "atomic_cas"
+        skip_sync: Skip CTA synchronization after acquiring the barrier.
+        sync_before: Synchronize the CTA before waiting. (default: False)
+    """
+    tl.static_assert(
+        (sem == "acquire" or sem == "relaxed") or sem == "release",
+        "Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
+    )
+    tl.static_assert(
+        op == "ld" or op == "atomic_cas",
+        "Invalid op. options: 'ld', 'atomic_cas'. ",
+    )
+
+    tl.static_assert(
+        addr.dtype == tl.pointer_type(tl.int32),
+        "Invalid barrier value type. Only supports int32 for multi barrier signal. ",
+    )
+
+    addr = tl.ravel(addr)
+
+    tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
+    tl.static_assert(addr.shape[0] <= 32, "Wait on at most 32 barriers at a time. ")
+
+    # Assume Triton always sets tid.y == tid.z == 0.
+    if op == "ld":
+        tl.inline_asm_elementwise(
+            f"""
+            {{
+                .reg .u32   %tmp32_<3>;
+                .reg .pred  %p<2>;
+
+                mov.u32 %tmp32_0, %tid.x;
+                setp.lt.s32 %p1, %tmp32_0, $2;
+
+                mov.u32 $0, 0;
+                // initialize output register $0 to 0
+                wait_block:
+                @%p1 ld.global.{sem}.{scope}.u32 $0, [$1];
+                setp.ne.u32 %p0, $0, $3;
+                and.pred %p0, %p0, %p1;
+                @%p0 bra wait_block;
+            }}
+            """,
+            "=r, l, r, r",
+            [addr, addr.shape[0], expect],
+            dtype=addr.dtype.element_ty,
+            is_pure=False,
+            pack=1,
+        )
+    elif op == "atomic_cas":
+        tl.inline_asm_elementwise(
+            f"""
+            {{
+                .reg .u32   %tmp32_<3>;
+                .reg .pred  %p<2>;
+
+                mov.u32 %tmp32_0, %tid.x;
+                setp.lt.s32 %p1, %tmp32_0, $2;
+
+                mov.u32 $0, 0;
+                // initialize output register $0 to 0
+                wait_block:
+                @%p1 atom.global.{sem}.{scope}.cas.b32 $0, [$1], $3, $4;
+                setp.ne.u32 %p0, $0, $3;
+                and.pred %p0, %p0, %p1;
+                @%p0 bra wait_block;
+            }}
+            """,
+            "=r, l, r, r, r",
+            [addr, addr.shape[0], expect, update],
+            dtype=addr.dtype.element_ty,
+            is_pure=False,
+            pack=1,
+        )
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for wait signal on gmem barrier. "
+        )
+
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
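For orientation, here is a minimal, hedged sketch of driving the new helper directly from a Triton kernel; the kernel name and launch configuration are illustrative, while the argument values mirror the generated code in the expected outputs below:

import torch
import triton
import triton.language as tl
import helion

@triton.jit
def _wait_on_barrier_block(signal_pad, _NUM_BARRIERS: tl.constexpr):
    # Build a 1D tensor of barrier addresses; thread i spins on barrier i
    # (threads with tid.x >= _NUM_BARRIERS are predicated off in the PTX).
    indices = tl.arange(0, _NUM_BARRIERS).to(tl.int32)
    helion.runtime.triton_wait_multiple_signal(
        addr=signal_pad + indices,
        expect=1,  # wait until each barrier holds 1
        update=0,  # unused for op="ld"
        sem="acquire",
        scope="gpu",
        op="ld",
        skip_sync=False,  # bar.sync once all barriers are acquired
    )

# Usage sketch (assumes a CUDA device; the barriers must be int32):
# signal_pad = torch.ones(4, device="cuda", dtype=torch.int32)
# _wait_on_barrier_block[(1,)](signal_pad, _NUM_BARRIERS=4)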
diff --git a/test/test_signal_wait.expected b/test/test_signal_wait.expected
index 473ef2af..746633c1 100644
--- a/test/test_signal_wait.expected
+++ b/test/test_signal_wait.expected
@@ -1,6 +1,37 @@
 This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py. Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
 
+--- assertExpectedJournal(TestWait.test_global_sync)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_multi_bar_sync_kernel_kernel(signal_pad, signal_pad_stride_0, signal_pad_stride_1, N, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    for offset_1 in tl.range(0, N.to(tl.int32), step=_BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        helion.runtime.triton_send_signal(addr=signal_pad + (indices_1 * signal_pad_stride_0 + offset_0 * signal_pad_stride_1), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=True)
+        helion.runtime.triton_wait_multiple_signal(addr=signal_pad + (offset_0 * signal_pad_stride_0 + indices_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+
+def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor):
+    M, N = signal_pad.shape
+    assert M == N
+    _BLOCK_SIZE_1 = N
+    _gmem_multi_bar_sync_kernel_kernel[N,](signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_multi_bar_sync_kernel_make_precompiler(signal_pad: torch.Tensor):
+    M, N = signal_pad.shape
+    assert M == N
+    _BLOCK_SIZE_1 = N
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_multi_bar_sync_kernel_kernel)(signal_pad, signal_pad.stride(0), signal_pad.stride(1), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestWait.test_signal_basic)
 from __future__ import annotations
 
 import torch
@@ -76,6 +107,33 @@ def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
 
+--- assertExpectedJournal(TestWait.test_signal_multiple_cas)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=0, update=1, sem='release', scope='gpu', op='atomic_cas', skip_sync=True, sync_before=not False)
+
+def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    _gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    n, = signal_pad.shape
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
 --- assertExpectedJournal(TestWait.test_wait_2d_tile)
 from __future__ import annotations
 
@@ -144,3 +202,66 @@ def _gmem_wait_kernel_make_precompiler(signal_pad: torch.Tensor):
     from helion.runtime.precompile_shim import make_precompiler
     return make_precompiler(_gmem_wait_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), num_warps=4, num_stages=3)
 
+--- assertExpectedJournal(TestWait.test_wait_multi_bar)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+import test.test_signal_wait as _source_module
+
+@triton.jit
+def _gmem_wait_multi_bar_kernel_kernel(signal_pad, out, out_stride_0, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tl.store(out + tile_id * out_stride_0, tile_id, None)
+
+def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
+    _BLOCK_SIZE_0 = 4
+    _gmem_wait_multi_bar_kernel_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
+def _gmem_wait_multi_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    out = torch.empty(n, dtype=torch.int32, device=_source_module.DEVICE)
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_wait_multi_bar_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+
+--- assertExpectedJournal(TestWait.test_wait_multi_bar_cas)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_wait_multi_bar_kernel_cas_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    helion.runtime.triton_wait_multiple_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, expect=1, update=2, sem='acquire', scope='gpu', op='atomic_cas', skip_sync=False)
+
+def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    _BLOCK_SIZE_0 = 4
+    _gmem_wait_multi_bar_kernel_cas_kernel[triton.cdiv(N, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return signal_pad
+
+def _gmem_wait_multi_bar_kernel_cas_make_precompiler(signal_pad: torch.Tensor):
+    N, = signal_pad.shape
+    n = 4
+    _BLOCK_SIZE_0 = 4
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_wait_multi_bar_kernel_cas_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
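One detail worth noting in the journals above: a CAS-based signal is lowered through the wait helper rather than triton_send_signal, since atomic_cas must first observe wait_for in each barrier before swapping in the new value (hence sync_before=not False in the generated call). A condensed sketch of the Helion-level pattern, mirroring test_signal_multiple_cas from the test diff below (the kernel name is illustrative; assumes a CUDA device and an int32 pad):

import torch
import helion
import helion.language as hl

@helion.kernel
def signal_all_cas(signal_pad: torch.Tensor) -> torch.Tensor:
    (n,) = signal_pad.shape
    for tile in hl.tile(n):
        # CAS each barrier in the tile from 0 to 1; each thread spins until
        # its barrier actually holds 0 before the swap succeeds.
        hl.signal(signal_pad, [tile], wait_for=0, signal=1, op="atomic_cas")
    return signal_pad

# signal_all_cas(torch.zeros(16, device="cuda", dtype=torch.int32))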
diff --git a/test/test_signal_wait.py b/test/test_signal_wait.py
index 720e1530..dbff6046 100644
--- a/test/test_signal_wait.py
+++ b/test/test_signal_wait.py
@@ -54,6 +54,50 @@ def wait_for_2d_tile_kernel(
         torch.testing.assert_close(result, x)
         self.assertExpectedJournal(code)
 
+    def test_wait_multi_bar(self):
+        @helion.kernel
+        def gmem_wait_multi_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (N,) = signal_pad.shape
+            n = hl.register_block_size(N)
+            out = torch.empty(n, dtype=torch.int32, device=DEVICE)
+
+            for tile in hl.tile(N, block_size=n):
+                hl.wait(signal_pad, [tile], signal=1)
+                out[tile.id] = tile.id
+
+            return out
+
+        signal_pad = torch.ones(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_wait_multi_bar_kernel, (signal_pad,), block_size=[4]
+        )
+        torch.testing.assert_close(
+            result, torch.arange(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.maxDiff = None
+        self.assertExpectedJournal(code)
+
+    def test_wait_multi_bar_cas(self):
+        @helion.kernel
+        def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor) -> torch.Tensor:
+            (N,) = signal_pad.shape
+            n = hl.register_block_size(N)
+
+            for tile in hl.tile(N, block_size=n):
+                hl.wait(signal_pad, [tile], signal=1, update=2, op="atomic_cas")
+
+            return signal_pad
+
+        signal_pad = torch.ones(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_wait_multi_bar_kernel_cas, (signal_pad,), block_size=[4]
+        )
+        torch.testing.assert_close(
+            result, torch.full((16,), fill_value=2, device=DEVICE, dtype=torch.int32)
+        )
+        self.maxDiff = None
+        self.assertExpectedJournal(code)
+
     def test_signal_basic(self):
         @helion.kernel
         def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
@@ -103,7 +147,26 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertExpectedJournal(code)
 
-    def test_sent_recieve_cta(self):
+    def test_signal_multiple_cas(self):
+        @helion.kernel
+        def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            (n,) = signal_pad.shape
+            for tile in hl.tile(n):
+                hl.signal(signal_pad, [tile], wait_for=0, signal=1, op="atomic_cas")
+            return signal_pad
+
+        signal_pad = torch.zeros(16, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(
+            gmem_signal_tensor_bar_kernel,
+            (signal_pad,),
+            block_size=[4],
+        )
+        torch.testing.assert_close(
+            result, torch.ones(16, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_send_receive_cta(self):
         @helion.kernel
         def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
             (n,) = signal_pad.shape
@@ -122,6 +185,51 @@ def gmem_signal_n_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         self.assertIn("helion.runtime.triton_send_signal", code)
         self.assertIn("helion.runtime.triton_wait_signal", code)
 
+    def test_global_sync(self):
+        @helion.kernel
+        def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            M, N = signal_pad.shape
+            assert M == N
+            for i in hl.grid(N):
+                for tile in hl.tile(N, block_size=N):
+                    hl.signal(signal_pad, [tile, i], signal=1, skip_sync=True)
+                    hl.wait(signal_pad, [i, tile], signal=1)
+            return signal_pad
+
+        signal_pad = torch.zeros(4, 4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_multi_bar_sync_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.ones(4, 4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_global_sync_cas(self):
+        @helion.kernel
+        def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            M, N = signal_pad.shape
+            assert M == N
+            for i in hl.grid(N):
+                for tile in hl.tile(N, block_size=N):
+                    hl.signal(
+                        signal_pad,
+                        [tile, i],
+                        signal=1,
+                        wait_for=0,
+                        skip_sync=True,
+                        op="atomic_cas",
+                    )
+                    hl.wait(signal_pad, [i, tile], signal=1, update=2, op="atomic_cas")
+            return signal_pad
+
+        signal_pad = torch.zeros(4, 4, device=DEVICE, dtype=torch.int32)
+
+        code, result = code_and_output(gmem_multi_bar_sync_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.full((4, 4), fill_value=2, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertIn("atomic_cas", code)
+
 
 if __name__ == "__main__":
     unittest.main()
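Finally, the global-sync tests above rely on a transposed pair of indexings that is easy to miss: program i writes column i (signal_pad[j, i] for all j) to announce itself, then spins on row i (signal_pad[i, j] for all j) until every peer has announced. A condensed, annotated sketch of just that pair, mirroring gmem_multi_bar_sync_kernel (the function name and device are illustrative):

import torch
import helion
import helion.language as hl

@helion.kernel
def global_sync(signal_pad: torch.Tensor) -> torch.Tensor:
    M, N = signal_pad.shape  # one barrier per (waiter, signaler) pair
    assert M == N
    for i in hl.grid(N):  # program i
        for tile in hl.tile(N, block_size=N):
            # Column write: tell every program j that program i has arrived.
            hl.signal(signal_pad, [tile, i], signal=1, skip_sync=True)
            # Row wait: block until every program has announced itself to i,
            # i.e. all of signal_pad[i, :] holds 1 (the multi-barrier path).
            hl.wait(signal_pad, [i, tile], signal=1)
    return signal_pad

# All 16 barriers end at 1 once every program has passed the barrier:
# global_sync(torch.zeros(4, 4, device="cuda", dtype=torch.int32))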