Skip to content

Commit 6acea31

Browse files
committed
[BC breaking] Add MulticastTensor support to hl.signal & hl.wait (as_ptrs)
stack-info: PR: #261, branch: joydddd/stack/13
1 parent 7cc53a9 commit 6acea31

File tree

6 files changed

+323
-140
lines changed

6 files changed

+323
-140
lines changed

examples/all_gather_matmul.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,6 @@ def helion_matmul_w_progress(
9696
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
9797
],
9898
signal=1,
99-
update=None,
100-
op="ld",
101-
scope="gpu",
102-
sem="acquire",
10399
)
104100
for tile_k in hl.tile(K):
105101
# TODO(joydddd): use a_shared and skip barrier when data is available on local rank.

helion/language/multicast_tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from . import _decorators
99

1010
if TYPE_CHECKING:
11+
from typing import Sequence
12+
1113
from .._compiler.type_propagation import TypeInfo
1214
from .._compiler.variable_origin import Origin
1315

@@ -67,6 +69,11 @@ def __setitem__( # pyright ignore[reportIncompatibleMethodOverride]
6769
) -> None:
6870
raise exc.NotInsideKernel
6971

72+
def new_empty(
73+
self, *args: Sequence[int | torch.SymInt], **kwargs: dict
74+
) -> torch.Tensor:
75+
return self.tensor_like.new_empty(*args, **kwargs) # pyright: ignore[reportCallIssue]
76+
7077

7178
def multicast_like(
7279
tensor_like: torch.Tensor,

0 commit comments

Comments
 (0)