
Commit 3044010

Add hl.wait & AllGather Matmul example (via hl_ext helper).
stack-info: PR: #189, branch: joydddd/stack/5
1 parent 902741b commit 3044010
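
In user code, the new pieces combine into a producer/consumer pattern: a copy engine gathers remote shards and publishes one progress flag per chunk, while the Helion matmul kernel blocks each output tile on the flag guarding the rows it reads. A condensed sketch, distilled from examples/all_gather_matmul.py below (not runnable as-is):

# Producer side (backend stream): after copying chunk i, publish its flag.
#     chunks[i].copy_(remote_buf)
#     symm_mem_hdl.stream_write_value32(progress, offset=i, val=1)
# Consumer side (inside the Helion kernel):
for tile_m, tile_n in hl.tile([M, N]):
    hl.wait(
        progress,
        [tile_m.begin // (M_per_rank // SPLITS_PER_RANK)],
        signal=1, op="ld", sem="acquire", scope="gpu",
    )
    ...  # accumulate the matmul for this tile; its input rows are now present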


7 files changed, +564 -0 lines changed


examples/all_gather_matmul.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


def copy_engine_all_gather_w_progress(
    output: torch.Tensor,
    inp: torch.Tensor,  # Must be a symmetric tensor
    progress: torch.Tensor,
    splits_per_rank: int,
    backend_stream: torch.cuda.Stream | None = None,
) -> torch.cuda.Stream:
    # Fall back to the symmetric-memory backend copy stream if the caller did not pass one.
    if backend_stream is None:
        backend_stream = symm_mem._get_backend_stream(priority=-1)
    assert inp.is_contiguous()
    symm_mem_group = dist.group.WORLD
    if symm_mem_group is None:
        raise RuntimeError("No symmetric memory group available")
    symm_mem_hdl = symm_mem.rendezvous(inp, group=symm_mem_group)
    assert symm_mem_hdl is not None

    rank = symm_mem_hdl.rank
    world_size = symm_mem_hdl.world_size

    assert inp.numel() % splits_per_rank == 0
    assert progress.numel() >= world_size * splits_per_rank

    output_shape = list(inp.shape)
    output_shape[0] *= world_size
    assert list(output.shape) == output_shape, (list(output.shape), output_shape)

    chunks = output.chunk(world_size * splits_per_rank)

    symm_mem_hdl.barrier()
    backend_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(backend_stream):
        # Pull chunks from peers in ring order and flag each one as soon as it lands.
        for step in range(world_size):
            src_rank = (rank + step + 1) % world_size
            for split_id in range(splits_per_rank):
                src_buf = symm_mem_hdl.get_buffer(
                    src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
                )
                chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
                # cuStreamWriteValue32 issues a system-level fence before the write
                symm_mem_hdl.stream_write_value32(
                    progress,
                    offset=src_rank * splits_per_rank + split_id,
                    val=1,
                )
        symm_mem_hdl.barrier()

    return backend_stream


# TODO(joydddd): add support for auto-tuning on multi-process runs.
# Please hardcode the Helion config for multi-process runs initiated by torchrun.
@helion.jit(
    config=helion.Config(
        block_sizes=[128, 256, 64],
        num_warps=8,
        num_stages=3,
        indexing="block_ptr",
    ),
    static_shapes=True,
)
def helion_matmul_w_progress(
    a: torch.Tensor,
    a_shared: torch.Tensor,
    b: torch.Tensor,
    progress: torch.Tensor,
    SPLITS_PER_RANK: int,
    RANK: int,
) -> torch.Tensor:
    M, K = a.size()
    K2, N = b.size()
    assert K2 == K, f"size mismatch {K2} != {K}"

    out = torch.empty(
        [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
    )

    M_per_rank = a_shared.size(0)

    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        # Block this tile until the all-gather has published the rows it reads.
        hl.wait(
            progress,
            [
                tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
            ],
            signal=1,
            update=None,
            op="ld",
            scope="gpu",
            sem="acquire",
        )
        for tile_k in hl.tile(K):
            # TODO(joydddd): use a_shared and skip the barrier when the data is available on the local rank.
            # if tile_k.begin // M_per_rank == RANK:
            #     acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n])
            # else:
            #     hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
            acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
        out[tile_m, tile_n] = acc
    return out


def helion_all_gather_matmul(
    a_shared: torch.Tensor,
    b: torch.Tensor,
    a_out: torch.Tensor | None = None,
    progress: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]:
    configs = {
        "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
    }

    symm_mem_group = dist.group.WORLD
    if symm_mem_group is None:
        raise RuntimeError("No symmetric memory group available")

    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group)

    a_shape = list(a_shared.shape)
    a_shape[0] *= symm_mem_hdl.world_size

    configs["RANK"] = symm_mem_hdl.rank
    configs["WORLD_SIZE"] = symm_mem_hdl.world_size

    if a_out is None:
        a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device)

    if progress is None:
        progress = torch.zeros(
            symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"],
            dtype=torch.uint32,
            device=a_shared.device,
        )
    else:
        progress.fill_(
            0
        )  # Reset progress to 0. Maybe we should reset inside the kernel using cas?

    backend_stream = copy_engine_all_gather_w_progress(
        a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
    )

    c = helion_matmul_w_progress(
        a_out,
        a_shared,
        b,
        progress,
        SPLITS_PER_RANK=configs["SPLITS_PER_RANK"],
        RANK=configs["RANK"],
    )
    assert type(c) is torch.Tensor

    torch.cuda.current_stream().wait_stream(backend_stream)

    return a_out, c


def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
    a_shared = symm_mem.empty(
        M // world_size, K, dtype=torch.bfloat16, device=device
    ).normal_()
    b = torch.randn((K, N), device="cuda", dtype=torch.bfloat16).T.contiguous().T

    a_out, c = helion_all_gather_matmul(a_shared, b)

    golden_a = a_shared.clone()
    dist_group = dist.group.WORLD
    if dist_group is None:
        raise RuntimeError("No distributed group available")
    ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
        golden_a, [b], gather_dim=0, group_name=dist_group.group_name
    )
    torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
    torch.testing.assert_close(a_out, ag_golden)


def main() -> None:
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(4096, 6656, 16384, world_size, device)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    Run with:
    torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 examples/all_gather_matmul.py
    """
    main()
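
For intuition about the hl.wait index in helion_matmul_w_progress, here is the arithmetic with the sizes test() uses on an 8-rank run (the numbers below are mine, not part of the commit): M = 4096 gives M_per_rank = 512, and with SPLITS_PER_RANK = 1 one progress flag guards 512 gathered rows, so with block_m = 128 four consecutive row tiles share a flag.

# Illustrative only: sizes taken from test(4096, 6656, 16384, ...) with 8 ranks.
M, world_size, SPLITS_PER_RANK, block_m = 4096, 8, 1, 128
M_per_rank = M // world_size                   # 512 rows contributed by each rank
rows_per_flag = M_per_rank // SPLITS_PER_RANK  # 512 rows guarded by one progress flag
flags = [begin // rows_per_flag for begin in range(0, M, block_m)]
# flags[:8] == [0, 0, 0, 0, 1, 1, 1, 1]: each row tile waits only on the chunk
# the copy engine has already marked complete, so compute overlaps the gather.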

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
 from .scan_ops import cumsum as cumsum
+from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
 from .tile_ops import tile_end as tile_end

helion/language/signal_wait.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch.fx import has_side_effect

from .. import exc
from . import _decorators

if TYPE_CHECKING:
    import ast

    from .._compiler.inductor_lowering import CodegenState


@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def wait(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    update: int | None = None,
    op: str = "ld",
    sem: str = "acquire",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> None:
    """Wait until all entries of the signal_pad slice are equal to the signal value.
    Args:
        signal_pad: The signal pad tensor to wait on
        index: Indices into the signal_pad tensor
        signal: The value to wait for
        update: Atomically update the signal_pad tensor with this value once the signal is observed. (default: None)
        op: The memory op for acquiring the lock (default: 'ld')
        sem: The memory semantic for acquiring the lock (default: 'acquire')
        scope: The scope of the lock (default: 'gpu')
        skip_sync: Skip the syncthreads after the wait (default: False)

    Returns:
        None
    """
    raise exc.NotInsideKernel


@_decorators.prepare_args(wait)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    update: int | None = None,
    op: str = "ld",
    sem: str = "acquire",
    scope: str = "gpu",
    skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
    from helion.language.tile_proxy import Tile

    valid_ops = {"ld", "atomic_cas"}
    valid_sems = {"relaxed", "acquire", "acq_rel"}
    valid_scopes = {"sys", "gpu"}

    if op not in valid_ops:
        raise ValueError(f"Invalid Wait op '{op}'. Must be one of {valid_ops}.")

    if sem == "release":
        raise ValueError(
            f"Do not use '{sem}' for wait patterns. Wait sem must be one of {valid_sems}."
        )

    if sem not in valid_sems:
        raise ValueError(
            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
        )

    if op == "atomic_cas" and not update:
        raise ValueError(
            f"{op} without an update value. Do you want to use 'ld' instead?"
        )

    if op == "ld":
        assert update is None
        update = 0

    if scope not in valid_scopes:
        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

    # TODO(joydddd): add support for non-scalar index into signal_pad
    for i in index:
        assert isinstance(i, int | torch.SymInt)

    index = Tile._prepare_index(index)
    index = Tile._tiles_to_sizes(index)

    return (signal_pad, index, signal, update, op, sem, scope, skip_sync)


@_decorators.register_fake(wait)
def _(
    signal_pad: torch.Tensor,
    index: list[object],
    signal: int = 1,
    update: int | None = None,
    op: str = "ld",
    sem: str = "acquire",
    scope: str = "sys",
    skip_sync: bool = False,
) -> None:
    return None


@_decorators.codegen(wait)
def _(state: CodegenState) -> ast.AST:
    import ast

    from .._compiler.ast_extension import expr_from_string
    from .._compiler.indexing_strategy import SubscriptIndexing

    signal_pad = state.proxy_arg(0)
    index = state.proxy_arg(1)
    signal = state.proxy_arg(2)
    update = state.proxy_arg(3)
    op = state.proxy_arg(4)
    sem = state.proxy_arg(5)
    scope = state.proxy_arg(6)
    skip_sync = state.proxy_arg(7)

    assert isinstance(signal_pad, torch.Tensor)
    assert isinstance(index, list)

    indices = SubscriptIndexing.create(state, signal_pad, index)
    signal_pad_name = state.device_function.tensor_arg(signal_pad).name

    signal_expr = ast.Constant(value=signal)
    update_expr = ast.Constant(value=update)

    assert type(op) is str
    assert type(sem) is str
    assert type(scope) is str

    # Emit a call to the runtime helper; offset/signal/update are bound below via expr_from_string.
    call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"

    return expr_from_string(
        call_triton_wait_signal,
        offset=indices.index_expr,
        signal=signal_expr,
        update=update_expr,
    )
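
To make the validation above concrete, a few call shapes and how prepare_args treats them. These calls are only legal inside a Helion kernel (at the Python level wait raises exc.NotInsideKernel), and `progress` is just a stand-in signal pad:

hl.wait(progress, [0], signal=1)                             # op='ld': update must stay None and is forced to 0
hl.wait(progress, [0], signal=1, update=2, op="atomic_cas")  # ok: CAS writes 2 once the value 1 is observed
hl.wait(progress, [0], signal=1, op="atomic_cas")            # ValueError: atomic_cas requires an update value
hl.wait(progress, [0], signal=1, sem="release")              # ValueError: 'release' is not a wait semantic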

helion/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 from .config import Config as Config
 from .kernel import Kernel as Kernel
 from .kernel import kernel as kernel
+from .triton_helpers import triton_wait_signal as triton_wait_signal


 def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
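
This re-export is what lets the string emitted by the codegen hook in signal_wait.py resolve helion.runtime.triton_wait_signal by name inside generated kernels. Judging only from that f-string, the emitted call should look roughly like the sketch below; the argument name `progress` and the offset expression are placeholders, not actual compiler output:

# Hypothetical shape of the generated call (argument names are illustrative).
helion.runtime.triton_wait_signal(
    addr=progress + offset_0,  # offset_0: flattened index produced by SubscriptIndexing
    expect=1,                  # `signal`
    update=0,                  # `update` (op='ld' forces 0)
    sem='acquire',
    scope='gpu',
    op='ld',
    skip_sync=False,
)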
