Add hl.wait & AllGather Matmul example (via hl_ext helper). #189

Merged · 1 commit · Jul 9, 2025
211 changes: 211 additions & 0 deletions examples/all_gather_matmul.py
@@ -0,0 +1,211 @@
from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


def copy_engine_all_gather_w_progress(
output: torch.Tensor,
    inp: torch.Tensor,  # Must be a symmetric memory tensor
progress: torch.Tensor,
splits_per_rank: int,
backend_stream: torch.cuda.Stream | None = None,
) -> torch.cuda.Stream:
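    """All-gather `inp` from every rank into `output` on a dedicated copy stream,
    setting progress[src_rank * splits_per_rank + split_id] to 1 as each chunk
    arrives so a consumer kernel can wait on it. Returns the copy stream."""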
    if backend_stream is None:
        backend_stream = symm_mem._get_backend_stream(priority=-1)
assert inp.is_contiguous()
symm_mem_group = dist.group.WORLD
if symm_mem_group is None:
raise RuntimeError("No symmetric memory group available")
symm_mem_hdl = symm_mem.rendezvous(inp, group=symm_mem_group)
assert symm_mem_hdl is not None

rank = symm_mem_hdl.rank
world_size = symm_mem_hdl.world_size

assert inp.numel() % splits_per_rank == 0
assert progress.numel() >= world_size * splits_per_rank

output_shape = list(inp.shape)
output_shape[0] *= world_size
assert list(output.shape) == output_shape, (list(output.shape), output_shape)

chunks = output.chunk(world_size * splits_per_rank)

symm_mem_hdl.barrier()
backend_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(backend_stream):
for step in range(world_size):
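            # Visit source ranks round-robin starting from the next rank, so at
            # each step every rank pulls from a different peer.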
src_rank = (rank + step + 1) % world_size
for split_id in range(splits_per_rank):
src_buf = symm_mem_hdl.get_buffer(
src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
)
chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
# cuStreamWriteValue32 issues a system level fence before the write
symm_mem_hdl.stream_write_value32(
progress,
offset=src_rank * splits_per_rank + split_id,
val=1,
)
symm_mem_hdl.barrier()

return backend_stream


# TODO(joydddd): add support for auto-tuning on multi-process runs.
# For now, hardcode the Helion config for multi-process runs launched by torchrun.
@helion.jit(
config=helion.Config(
block_sizes=[128, 256, 64],
num_warps=8,
num_stages=3,
indexing="block_ptr",
),
static_shapes=True,
)
def helion_matmul_w_progress(
a: torch.Tensor,
a_shared: torch.Tensor,
b: torch.Tensor,
progress: torch.Tensor,
SPLITS_PER_RANK: int,
RANK: int,
) -> torch.Tensor:
M, K = a.size()
K2, N = b.size()
assert K2 == K, f"size mismatch {K2} != {K}"

out = torch.empty(
[M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
)

M_per_rank = a_shared.size(0)

for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
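        # Block until the copy engine has flagged this tile's rows of the
        # all-gathered `a` as ready (it writes 1 into `progress` per split).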
hl.wait(
progress,
[
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
],
signal=1,
update=None,
op="ld",
scope="gpu",
sem="acquire",
)
for tile_k in hl.tile(K):
            # TODO(joydddd): use a_shared and skip the barrier when the data is already available on the local rank.
# if tile_k.begin // M_per_rank == RANK:
# acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n])
# else:
# hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
out[tile_m, tile_n] = acc
return out


def helion_all_gather_matmul(
a_shared: torch.Tensor,
b: torch.Tensor,
a_out: torch.Tensor | None = None,
progress: torch.Tensor | None = None,
**kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]:
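    """Overlap an all-gather of `a_shared` (on a copy stream) with the matmul
    against `b`. Returns (a_out, c): the gathered A and the product A @ b."""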
configs = {
"SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
}

symm_mem_group = dist.group.WORLD
if symm_mem_group is None:
raise RuntimeError("No symmetric memory group available")

symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group)

a_shape = list(a_shared.shape)
a_shape[0] *= symm_mem_hdl.world_size

configs["RANK"] = symm_mem_hdl.rank
configs["WORLD_SIZE"] = symm_mem_hdl.world_size

if a_out is None:
a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device)

if progress is None:
progress = torch.zeros(
symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"],
dtype=torch.uint32,
device=a_shared.device,
)
else:
progress.fill_(
0
) # Reset progress to 0. Maybe we should reset inside the kernel using cas?

backend_stream = copy_engine_all_gather_w_progress(
a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
)

c = helion_matmul_w_progress(
a_out,
a_shared,
b,
progress,
SPLITS_PER_RANK=configs["SPLITS_PER_RANK"],
RANK=configs["RANK"],
)
assert type(c) is torch.Tensor

torch.cuda.current_stream().wait_stream(backend_stream)

return a_out, c


def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
a_shared = symm_mem.empty(
M // world_size, K, dtype=torch.bfloat16, device=device
).normal_()
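    # Lay out b column-major (it is the transpose of a contiguous [N, K] tensor).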
b = torch.randn((K, N), device="cuda", dtype=torch.bfloat16).T.contiguous().T

a_out, c = helion_all_gather_matmul(a_shared, b)

golden_a = a_shared.clone()
dist_group = dist.group.WORLD
if dist_group is None:
raise RuntimeError("No distributed group available")
ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
golden_a, [b], gather_dim=0, group_name=dist_group.group_name
)
torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
torch.testing.assert_close(a_out, ag_golden)


def main() -> None:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.manual_seed(42 + rank)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")
test(4096, 6656, 16384, world_size, device)

dist.destroy_process_group()


if __name__ == "__main__":
"""
Run with:
torchrun \
--nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 examples/all_gather_matmul.py
"""
main()
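
For reference, the producer/consumer pattern that hl.wait enables can be boiled
down to a small single-process sketch. This sketch is illustrative and not part
of the PR: the kernel name, the single-element flag, and the config values are
assumptions; only hl.wait, hl.tile, helion.jit, and helion.Config come from the
code above.

import torch

import helion
import helion.language as hl


# Hypothetical consumer kernel: every tile spins on a one-element uint32 flag
# that a producer (e.g. a copy-engine stream) sets to 1 once `x` (a 1-D tensor)
# is ready to be read.
@helion.jit(config=helion.Config(block_sizes=[64], num_warps=4, num_stages=2))
def scale_when_ready(x: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
    out = torch.empty([x.size(0)], dtype=x.dtype, device=x.device)
    for tile in hl.tile(x.size(0)):
        # Block until flag[0] == 1; acquire semantics make the producer's
        # writes to `x` visible before the load below.
        hl.wait(flag, [0], signal=1, op="ld", sem="acquire", scope="gpu")
        out[tile] = x[tile] * 2
    return out

On the host, the producer would allocate flag = torch.zeros(1, dtype=torch.uint32,
device="cuda"), kick off its copy on a separate stream, and then set the flag
(for example with stream_write_value32, as copy_engine_all_gather_w_progress does
above) once the data has landed.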
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -15,6 +15,7 @@
from .scan_ops import associative_scan as associative_scan
from .scan_ops import cumprod as cumprod
from .scan_ops import cumsum as cumsum
from .signal_wait import wait as wait
from .tile_ops import tile_begin as tile_begin
from .tile_ops import tile_block_size as tile_block_size
from .tile_ops import tile_end as tile_end
148 changes: 148 additions & 0 deletions helion/language/signal_wait.py
@@ -0,0 +1,148 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch.fx import has_side_effect

from .. import exc
from . import _decorators

if TYPE_CHECKING:
import ast

from .._compiler.inductor_lowering import CodegenState


@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def wait(
signal_pad: torch.Tensor,
index: list[object],
signal: int = 1,
update: int | None = None,
op: str = "ld",
sem: str = "acquire",
scope: str = "gpu",
skip_sync: bool = False,
) -> None:
"""Wait until all entries of the signal_pad slice are equal to the signal value.
Args:
signal_pad: The signal pad tensor to wait on
index: Indices to index into the signal_pad tensor
        signal: The value to wait for (default: 1)
        update: Atomically update the signal_pad tensor with this value once the signal is observed. (default: None)
        op: The memory op for acquiring the lock (default: 'ld')
        sem: The memory semantics for acquiring the lock (default: 'acquire')
scope: The scope of the lock (default: 'gpu')
skip_sync: Skip the syncthreads after the wait (default: False)

Returns:
None
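
    Example:
        Block a consumer tile until a producer has published its slice (pattern
        from examples/all_gather_matmul.py in this PR); ``flag_idx`` is a
        stand-in for whichever flat index the tile depends on:

            hl.wait(progress, [flag_idx], signal=1, op="ld", sem="acquire", scope="gpu")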
"""
raise exc.NotInsideKernel


@_decorators.prepare_args(wait)
def _(
signal_pad: torch.Tensor,
index: list[object],
signal: int = 1,
update: int | None = None,
op: str = "ld",
sem: str = "acquire",
scope: str = "gpu",
skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
from helion.language.tile_proxy import Tile

valid_ops = {"ld", "atomic_cas"}
valid_sems = {"relaxed", "acquire", "acq_rel"}
valid_scopes = {"sys", "gpu"}

if op not in valid_ops:
raise ValueError(f"Invalid Wait op '{op}'. Must be one of {valid_ops}. ")

if sem == "release":
raise ValueError(
f"Do not use '{sem}' for wait patterns. Wait sem must be one of {valid_sems}."
)

if sem not in valid_sems:
raise ValueError(
f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
)

if op == "atomic_cas" and not update:
raise ValueError(
f"{op} without an update value. Do you want to use 'ld' instead? "
)

if op == "ld":
assert update is None
update = 0

if scope not in valid_scopes:
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

# TODO(joydddd): add support for non scalar index into signal_pad
for i in index:
assert isinstance(i, int | torch.SymInt)

index = Tile._prepare_index(index)
index = Tile._tiles_to_sizes(index)

return (signal_pad, index, signal, update, op, sem, scope, skip_sync)


@_decorators.register_fake(wait)
def _(
signal_pad: torch.Tensor,
index: list[object],
signal: int = 1,
update: int | None = None,
op: str = "ld",
sem: str = "acquire",
scope: str = "sys",
skip_sync: bool = False,
) -> None:
return None


@_decorators.codegen(wait)
def _(state: CodegenState) -> ast.AST:
import ast

from .._compiler.ast_extension import expr_from_string
from .._compiler.indexing_strategy import SubscriptIndexing

signal_pad = state.proxy_arg(0)
index = state.proxy_arg(1)
signal = state.proxy_arg(2)
update = state.proxy_arg(3)
op = state.proxy_arg(4)
sem = state.proxy_arg(5)
scope = state.proxy_arg(6)
skip_sync = state.proxy_arg(7)

assert isinstance(signal_pad, torch.Tensor)
    assert isinstance(index, list)

indices = SubscriptIndexing.create(state, signal_pad, index)
signal_pad_name = state.device_function.tensor_arg(signal_pad).name

signal_expr = ast.Constant(value=signal)
update_expr = ast.Constant(value=update)

assert type(op) is str
assert type(sem) is str
assert type(scope) is str

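    # Build a call to the runtime helper; `offset` is the flat index into the
    # signal pad computed by SubscriptIndexing above.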
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"

return expr_from_string(
call_triton_wait_signal,
offset=indices.index_expr,
signal=signal_expr,
update=update_expr,
)
1 change: 1 addition & 0 deletions helion/runtime/__init__.py
@@ -8,6 +8,7 @@
from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel
from .triton_helpers import triton_wait_signal as triton_wait_signal


def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor: