Commit fd02b59

Add multicast tensor
stack-info: PR: #346, branch: joydddd/stack/17
1 parent 41fe6e9 commit fd02b59

10 files changed, +856 -21 lines changed


helion/_compiler/device_ir.py

Lines changed: 17 additions & 2 deletions
@@ -52,6 +52,7 @@
 from .type_propagation import GridIndexType
 from .type_propagation import IterType
 from .type_propagation import LiteralType
+from .type_propagation import MulticastTensorType
 from .type_propagation import NumericType
 from .type_propagation import SequenceType
 from .type_propagation import TensorType
@@ -330,6 +331,14 @@ def build_rolled_reductions(self) -> None:
                     can_roll_graphs = False
                     break

+            # Check if any graph contains a tensor multicast along rdim.
+            # If so, we can't roll any graphs in this reduction dimension.
+            for graph_info in self.graphs:
+                roller = ReductionRoller(self, rdim, {})
+                if roller.has_multicast_tensor_with_rdim(graph_info.graph):
+                    can_roll_graphs = False
+                    break
+
             if not can_roll_graphs:
                 first = False
                 continue
@@ -783,7 +792,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
             assert isinstance(target.value, ExtendedAST)
             assert target.value._type_info is not None
             target_origin = target.value._type_info.origin  # pyright: ignore[reportOptionalMemberAccess]
-            if not target_origin.is_host():
+            if not target_origin.is_host() and not isinstance(
+                target.value._type_info, MulticastTensorType
+            ):
                 # Get the variable name for the error message
                 var_name = (
                     target.value.id
@@ -808,7 +819,9 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
         assert isinstance(target.value, ExtendedAST)
         assert target.value._type_info is not None
         target_origin = target.value._type_info.origin
-        assert target_origin.is_host()
+        assert target_origin.is_host() or isinstance(
+            target.value._type_info, MulticastTensorType
+        )

         return hl.store(
             self.visit(target.value),  # pyright: ignore[reportArgumentType]
@@ -841,6 +854,8 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
             if isinstance(node.slice, ast.Constant):
                 return self.visit(value)[self.visit(node.slice)]  # pyright: ignore[reportIndexIssue]
             raise exc.InvalidSequenceSubscription(node.slice)
+        if type_info is not None and isinstance(type_info, MulticastTensorType):
+            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
         if type_info is not None and type_info.origin.is_host():
             return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
         return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
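
In short, visit_Subscript now sends multicast tensors down the same hl.load path as host tensors, while ordinary device tensors still lower to hl.subscript. A condensed sketch of that dispatch (lower_subscript is an invented name; the real logic is inline in visit_Subscript above):

def lower_subscript(self, value, type_info, node):
    # Multicast tensors always materialize through hl.load.
    if isinstance(type_info, MulticastTensorType):
        return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))
    # Host tensors are likewise loaded from memory.
    if type_info is not None and type_info.origin.is_host():
        return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))
    # Plain device tensors remain a subscript/view.
    return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice))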

helion/_compiler/indexing_strategy.py

Lines changed: 134 additions & 0 deletions
@@ -8,6 +8,7 @@

 import sympy
 import torch
+from torch._inductor.utils import triton_type
 import triton

 from .. import exc
@@ -20,10 +21,15 @@
 from .variable_origin import BlockSizeOrigin

 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from ..runtime.config import Config
     from .device_function import TensorDescriptorArg
     from .inductor_lowering import CodegenState

+    SymIntLike = torch.SymInt | int
+    ShapeLike = Sequence[SymIntLike]
+

 class IndexingStrategy:
     def codegen_load(
@@ -289,6 +295,134 @@ def codegen_store(
         )


+class MulticastIndexingStrategy:
+    @staticmethod
+    def get_broadcast_str(
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> tuple[str, str]:
+        multicast_broadcast_keys = [":" for _ in multicast_shape] + [
+            "None" for _ in subscript_shape
+        ]
+        multicast_broadcast = f"[{', '.join(multicast_broadcast_keys)}]"
+        tensor_broadcast_keys = ["None" for _ in multicast_shape] + [
+            ":" for _ in subscript_shape
+        ]
+        tensor_broadcast = f"[{', '.join(tensor_broadcast_keys)}]"
+
+        return multicast_broadcast, tensor_broadcast
+
+    @staticmethod
+    def get_mask_expr(
+        state: CodegenState,
+        indexing: SubscriptIndexing,
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> ast.AST | None:
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscript_shape
+            )
+        )
+
+        mask_exprs = []
+        dev_ptr_mask_exprs = []
+        # Generate Mask
+
+        for dim, size in enumerate(multicast_shape):
+            if (
+                index := CompileEnvironment.current().get_block_id(size)
+            ) is not None and (mask_var := state.codegen.mask_var(index)) is not None:
+                expand = state.tile_strategy.expand_str(multicast_shape, dim)
+                dev_ptr_mask_exprs.append(f"({mask_var}{expand})")
+
+        if dev_ptr_mask_exprs:
+            dev_ptr_mask_expr = f"({'&'.join(dev_ptr_mask_exprs)})"
+            if len(dev_ptr_mask_exprs) < len(multicast_shape):
+                dev_ptr_mask_expr = f"tl.broadcast_to({dev_ptr_mask_expr}, {state.tile_strategy.shape_str(multicast_shape)})"
+            dev_ptr_mask_expr = f"({dev_ptr_mask_expr}){multicast_broadcast}"
+            mask_exprs.append(dev_ptr_mask_expr)
+
+        if indexing.has_mask():
+            mask_exprs.append(f"(tensor_mask){tensor_broadcast}")
+            return expr_from_string(
+                "&".join(mask_exprs), tensor_mask=indexing.mask_expr
+            )
+        if mask_exprs:
+            return expr_from_string("&".join(mask_exprs))
+        return None
+
+    @staticmethod
+    def codegen_load(
+        state: CodegenState,
+        tensors: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = tensors
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
+        extra = ", other=0"
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+            extra = ""
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.load((base.to(tl.pointer_type({dtype}))){multicast_broadcast} + (offset){tensor_broadcast}, mask{extra})",
+            base=dev_ptrs_ast,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
+
+    @staticmethod
+    def codegen_store(
+        state: CodegenState,
+        tensors: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        value: ast.AST,
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = tensors
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.store(base.to(tl.pointer_type({dtype})){multicast_broadcast} + (offset){tensor_broadcast}, value, mask)",
+            base=dev_ptrs_ast,
+            value=value,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
+
+
 class SubscriptIndexing(NamedTuple):
     index_expr: ast.AST
     mask_expr: ast.AST
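
To make the generated indexing concrete, the snippet below re-implements the string logic of get_broadcast_str outside of Helion; the shapes and the printed result are illustrative only:

# Standalone copy of get_broadcast_str's logic (no Helion imports needed).
def get_broadcast_str(multicast_shape, subscript_shape):
    multicast_keys = [":"] * len(multicast_shape) + ["None"] * len(subscript_shape)
    tensor_keys = ["None"] * len(multicast_shape) + [":"] * len(subscript_shape)
    return f"[{', '.join(multicast_keys)}]", f"[{', '.join(tensor_keys)}]"

# One axis of device pointers (e.g. 4 peers) and a 2-D tile subscript:
print(get_broadcast_str([4], [64, 64]))
# -> ('[:, None, None]', '[None, :, :]')

With those strings, codegen_load emits roughly tl.load(base.to(tl.pointer_type(dtype))[:, None, None] + (offset)[None, :, :], mask, other=0), so the pointer array is broadcast against the per-tensor offsets and a single load gathers the same tile from every device pointer.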

helion/_compiler/roll_reduction.py

Lines changed: 30 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch
 from torch.fx import map_arg

+from ..language import _MEMORY_OPS
 from ..language._tracing_ops import _for_loop
 from ..language._tracing_ops import _get_symnode
 from ..language._tracing_ops import _host_tensor
@@ -277,6 +278,35 @@ def is_matmul_with_rdim(node: torch.fx.Node) -> bool:

         return any(is_matmul_with_rdim(node) for node in graph.nodes)

+    def has_multicast_tensor_with_rdim(self, graph: torch.fx.Graph) -> bool:
+        """Check if a graph contains multicast tensors with rdim inputs."""
+
+        def is_multicast_with_rdim(node: torch.fx.Node) -> bool:
+            """Check if a node is a multicast dev_ptr with rdim inputs."""
+            if node.op != "call_function":
+                return False
+
+            if node.target not in _MEMORY_OPS:
+                return False
+
+            host_tensor = node.args[0]
+
+            if not isinstance(host_tensor, tuple):
+                return False
+
+            # Check if multicast dims have rdim
+            if len(host_tensor) == 2:
+                assert isinstance(host_tensor[1], torch.fx.Node)
+                multicast = host_tensor[1].meta.get("val", None)
+                if isinstance(multicast, torch.Tensor):
+                    for size in multicast.size():
+                        block_idx = CompileEnvironment.current().get_block_id(size)
+                        if block_idx == self.rdim.block_id:
+                            return True
+            return False
+
+        return any(is_multicast_with_rdim(node) for node in graph.nodes)
+
     def process(self, graph: torch.fx.Graph) -> torch.fx.Graph:
         for node in graph.nodes:
             if self.should_go_in_inner_graph(node):
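
The predicate inspects only the fx graph: a memory op whose first argument is a (tensor_like, dev_ptrs) tuple counts as a multicast access, and rolling is disabled when any size of the dev_ptrs fake value maps to the reduction dimension. A minimal self-contained sketch of that check, using an invented stand-in node class and concrete sizes in place of Helion's CompileEnvironment bookkeeping:

from dataclasses import dataclass, field

import torch

@dataclass
class FakeNode:  # stand-in for torch.fx.Node, for illustration only
    op: str
    target: object
    args: tuple
    meta: dict = field(default_factory=dict)

def has_multicast_with_rdim(nodes, memory_ops, rdim_sizes) -> bool:
    def matches(node: FakeNode) -> bool:
        if node.op != "call_function" or node.target not in memory_ops:
            return False
        first_arg = node.args[0]
        if not (isinstance(first_arg, tuple) and len(first_arg) == 2):
            return False
        dev_ptrs_val = first_arg[1].meta.get("val")
        if not isinstance(dev_ptrs_val, torch.Tensor):
            return False
        # Real code maps each size to a block id and compares with rdim.block_id.
        return any(int(s) in rdim_sizes for s in dev_ptrs_val.size())
    return any(matches(n) for n in nodes)

dev_ptrs = FakeNode("call_function", "load", (), {"val": torch.empty(4, 128)})
load_node = FakeNode("call_function", "load", ((torch.empty(8), dev_ptrs), ["tile"]))
assert has_multicast_with_rdim([load_node], memory_ops={"load"}, rdim_sizes={128})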

helion/_compiler/type_propagation.py

Lines changed: 82 additions & 1 deletion
@@ -27,6 +27,7 @@
 from ..autotuner.config_spec import BlockSizeSpec
 from ..language._decorators import get_device_func_replacement
 from ..language._decorators import is_api_func
+from ..language.multicast_tensor import MulticastTensor
 from ..language.tile_proxy import Tile
 from ..language.tile_proxy import _CheckForIndexCalls
 from .ast_extension import ExtendedAST
@@ -1289,6 +1290,86 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
         return self.element_types[attr]


+class MulticastTensorType(ClassType):
+    element_types: dict[str, TypeInfo]  # pyright: ignore[reportIncompatibleVariableOverride]
+
+    def proxy(self) -> MulticastTensor:  # pyright: ignore[reportIncompatibleMethodOverride]
+        with proxy_tensor.disable_proxy_modes_tracing():
+            fake_mode = torch._C._unset_dispatch_mode(  # pyright: ignore[reportAttributeAccessIssue]
+                torch._C._TorchDispatchModeKey.FAKE  # pyright: ignore[reportAttributeAccessIssue]
+            )
+            try:
+                assert isinstance(self.element_types["tensor_like"], TensorType)
+                assert isinstance(self.element_types["dev_ptrs"], TensorType)
+                return MulticastTensor(
+                    self.element_types["tensor_like"].proxy(),
+                    self.element_types["dev_ptrs"].proxy(),
+                )
+            finally:
+                assert fake_mode is not None
+                torch._C._set_dispatch_mode(fake_mode)  # pyright: ignore[reportAttributeAccessIssue]
+
+    def merge(self, other: TypeInfo) -> TypeInfo:
+        if isinstance(other, MulticastTensorType):
+            self_elements = self.element_types
+            other_elements = other.element_types
+            if set(self_elements.keys()) == set(other_elements.keys()):
+                return MulticastTensorType(
+                    origin=other.origin,
+                    element_types={
+                        key: self_elements[key].merge(other_elements[key])
+                        for key in self_elements
+                    },
+                )
+        return super().merge(other)
+
+    def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
+        tensor_like_type = self.element_types["tensor_like"]
+        assert isinstance(tensor_like_type, TensorType)
+        size_like = tensor_like_type._device_indexing_size(key)
+
+        dev_ptrs_type = self.element_types["dev_ptrs"]
+        assert isinstance(dev_ptrs_type, TensorType)
+        multicast_size = list(dev_ptrs_type.fake_value.size())
+
+        return multicast_size + size_like
+
+    def propagate_setitem(
+        self, key: TypeInfo, value: TypeInfo, origin: Origin
+    ) -> TypeInfo:
+        if origin.is_host():
+            warning(exc.TensorOperationInWrapper)
+        else:
+            lhs_shape = self._device_indexing_size(key)
+            lhs_rank = len(lhs_shape)
+            if isinstance(value, TensorType):
+                rhs_rank = value.fake_value.ndim
+                if lhs_rank != rhs_rank:
+                    raise exc.RankMismatch(
+                        lhs_rank,
+                        rhs_rank,
+                        f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}",
+                    )
+            elif isinstance(value, (NumericType, LiteralType)):
+                # Allow scalar assignment to tensor (broadcasts to tensor shape)
+                pass
+            else:
+                raise exc.RequiresTensorInAssignment(value)
+        return self
+
+    def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
+        if origin.is_host():
+            warning(exc.TensorOperationInWrapper)
+
+        assert isinstance(self.element_types["tensor_like"], TensorType)
+        return TensorType(
+            origin,
+            self.element_types["tensor_like"]
+            .proxy()
+            .new_empty(self._device_indexing_size(key)),
+        )
+
+
 class SliceType(CollectionType):
     element_types: slice  # pyright: ignore[reportIncompatibleVariableOverride]

@@ -1614,7 +1695,7 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None:
         if isinstance(lhs, ast.Subscript):
             # TODO(jansel): test different types of subscript
             lhs_base_type = self.visit(lhs.value)
-            if isinstance(lhs_base_type, TensorType):
+            if isinstance(lhs_base_type, (TensorType, MulticastTensorType)):
                 self.visit(lhs)  # need to populate shape info
                 lhs_base_type = lhs_base_type.propagate_setitem(
                     self.visit(lhs.slice), rhs, self.origin()
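
The shape rule MulticastTensorType adds is that a subscript produces the dev_ptrs shape prepended to whatever shape the same subscript yields on tensor_like, and propagate_setitem checks ranks against that combined shape. A toy illustration with invented shapes:

import torch

dev_ptrs_shape = [4]                     # e.g. one device pointer per peer
subscripted_like = torch.empty(64, 32)   # what the subscript yields on tensor_like
result = torch.empty(*dev_ptrs_shape, *subscripted_like.shape)
assert list(result.shape) == [4, 64, 32]
# A store into the multicast tensor must therefore supply a value whose rank is
# len(dev_ptrs_shape) + the subscripted rank, or a scalar, which broadcasts.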

helion/exc.py

Lines changed: 18 additions & 0 deletions
@@ -138,6 +138,24 @@ class SpecializeArgType(BaseError):
     message = "hl.specialize() must be called on a size from an input tensor, got: {}"


+class MulticastTensorcOnHost(BaseError):
+    message = (
+        "hl.multicast_tensor must be called inside the `hl.tile` or `hl.grid` loop."
+    )
+
+
+class MulticastTensorDevPtrOnHost(BaseError):
+    message = "hl.multicast_tensor must be called with a dev_ptr tensor defined on device. Use `hl.load` to load a dev_ptrs tensor. "
+
+
+class MulticastTensorDevPtrDtype(BaseError):
+    message = "hl.multicast_tensor must be called with a dev_ptr tensor of dtype int64. Got: {0!s}"
+
+
+class MulticastTensorExampleOnDevice(BaseError):
+    message = "hl.multicast_tensor must be called with an example host tensor."
+
+
 class FailedToUnpackTupleAssign(BaseError):
     message = "Failed to unpack values in tuple assignment. Expected a sequence of size {0}, got type: {1!s}."

helion/language/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -13,6 +13,7 @@
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .multicast_tensor import multicast_like as multicast_like
 from .reduce_ops import reduce as reduce
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
@@ -29,3 +30,5 @@
 from .tunable_ops import register_reduction_dim as register_reduction_dim
 from .tunable_ops import register_tunable as register_tunable
 from .view_ops import subscript as subscript
+
+_MEMORY_OPS = [store, load, atomic_add, wait, signal]
