Skip to content

Commit dd129f4

Browse files
committed
[BC breaking] Add MulticastTensor support to hl.signal & hl.wait (as_ptrs)
stack-info: PR: #261, branch: joydddd/stack/13
1 parent fd02b59 commit dd129f4

File tree

6 files changed

+323
-140
lines changed

6 files changed

+323
-140
lines changed

examples/all_gather_matmul.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,6 @@ def helion_matmul_w_progress(
9696
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
9797
],
9898
signal=1,
99-
update=None,
100-
op="ld",
101-
scope="gpu",
102-
sem="acquire",
10399
)
104100
for tile_k in hl.tile(K):
105101
# TODO(joydddd): use a_shared and skip barrier when data is available on local rank.

helion/language/multicast_tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from . import _decorators
1010

1111
if TYPE_CHECKING:
12+
from typing import Sequence
13+
1214
from .._compiler.type_propagation import TypeInfo
1315
from .._compiler.variable_origin import Origin
1416

@@ -60,6 +62,11 @@ def __setitem__( # pyright ignore[reportIncompatibleMethodOverride]
6062
) -> None:
6163
raise exc.NotInsideKernel
6264

65+
def new_empty(
66+
self, *args: Sequence[int | torch.SymInt], **kwargs: dict
67+
) -> torch.Tensor:
68+
return self.tensor_like.new_empty(*args, **kwargs) # pyright: ignore[reportCallIssue]
69+
6370

6471
def multicast_like(
6572
tensor_like: torch.Tensor,

0 commit comments

Comments
 (0)