Commit 7cc53a9

Add multicast tensor
stack-info: PR: #346, branch: joydddd/stack/17
1 parent 41fe6e9 commit 7cc53a9

File tree

10 files changed: +876 −23 lines
helion/_compiler/device_ir.py

Lines changed: 13 additions & 4 deletions

@@ -52,6 +52,7 @@
 from .type_propagation import GridIndexType
 from .type_propagation import IterType
 from .type_propagation import LiteralType
+from .type_propagation import MulticastTensorType
 from .type_propagation import NumericType
 from .type_propagation import SequenceType
 from .type_propagation import TensorType
@@ -321,12 +322,14 @@ def build_rolled_reductions(self) -> None:
             graph_to_info = {}
             allow_loop = False
 
-            # First, check if any graph contains matmul with rdim
+            # First, check if any graph contains matmul or dev_ptrs multicast with rdim
             # If so, we can't roll any graphs in this reduction dimension
             can_roll_graphs = True
             for graph_info in self.graphs:
                 roller = ReductionRoller(self, rdim, {})
-                if roller.has_matmul_with_rdim(graph_info.graph):
+                if roller.has_matmul_with_rdim(
+                    graph_info.graph
+                ) or roller.has_multicast_tensor_with_rdim(graph_info.graph):
                     can_roll_graphs = False
                     break
 
@@ -783,7 +786,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
             assert isinstance(target.value, ExtendedAST)
             assert target.value._type_info is not None
             target_origin = target.value._type_info.origin  # pyright: ignore[reportOptionalMemberAccess]
-            if not target_origin.is_host():
+            if not target_origin.is_host() and not isinstance(
+                target.value._type_info, MulticastTensorType
+            ):
                 # Get the variable name for the error message
                 var_name = (
                     target.value.id
@@ -808,7 +813,9 @@ def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
        assert isinstance(target.value, ExtendedAST)
        assert target.value._type_info is not None
        target_origin = target.value._type_info.origin
-        assert target_origin.is_host()
+        assert target_origin.is_host() or isinstance(
+            target.value._type_info, MulticastTensorType
+        )
 
        return hl.store(
            self.visit(target.value),  # pyright: ignore[reportArgumentType]
@@ -841,6 +848,8 @@ def visit_Subscript(self, node: ast.Subscript) -> object:
            if isinstance(node.slice, ast.Constant):
                return self.visit(value)[self.visit(node.slice)]  # pyright: ignore[reportIndexIssue]
            raise exc.InvalidSequenceSubscription(node.slice)
+        if isinstance(type_info, MulticastTensorType):
+            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
        if type_info is not None and type_info.origin.is_host():
            return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]
        return hl.subscript(self.visit(value), self._subscript_slice_proxy(node.slice))  # pyright: ignore[reportArgumentType]

helion/_compiler/indexing_strategy.py

Lines changed: 153 additions & 0 deletions

@@ -8,6 +8,7 @@
 
 import sympy
 import torch
+from torch._inductor.utils import triton_type
 import triton
 
 from .. import exc
@@ -20,10 +21,15 @@
 from .variable_origin import BlockSizeOrigin
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+
     from ..runtime.config import Config
     from .device_function import TensorDescriptorArg
     from .inductor_lowering import CodegenState
 
+    SymIntLike = torch.SymInt | int
+    ShapeLike = Sequence[SymIntLike]
+
 
 class IndexingStrategy:
     def codegen_load(
@@ -289,6 +295,153 @@ def codegen_store(
        )
 
 
+class MulticastIndexingStrategy:
+    """
+    Generate pointer math for multicasting loads/stores to several device memory
+    pointers that share the same indexing.
+
+    The offset and mask are calculated for the tensor_like template tensor and then
+    broadcast to each dev_ptr, with the results stacked.
+
+    e.g. for a 1D offset tensor and a 1D dev_ptr array, the multicasted offset is:
+        multicast_offset = dev_ptrs[:, None] + offset[None, :]
+    """
+
+    @staticmethod
+    def get_broadcast_str(
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> tuple[str, str]:
+        """
+        Args:
+            multicast_shape: shape of the dev_ptr tensor.
+            subscript_shape: shape of the subscript into each individual tensor.
+
+        Returns:
+            the broadcast strings for dev_ptrs and the individual tensor offset.
+        """
+        multicast_broadcast_keys = [":" for _ in multicast_shape] + [
+            "None" for _ in subscript_shape
+        ]
+        multicast_broadcast = f"[{', '.join(multicast_broadcast_keys)}]"
+        tensor_broadcast_keys = ["None" for _ in multicast_shape] + [
+            ":" for _ in subscript_shape
+        ]
+        tensor_broadcast = f"[{', '.join(tensor_broadcast_keys)}]"
+
+        return multicast_broadcast, tensor_broadcast
+
+    @staticmethod
+    def get_mask_expr(
+        state: CodegenState,
+        indexing: SubscriptIndexing,
+        multicast_shape: ShapeLike,
+        subscript_shape: ShapeLike,
+    ) -> ast.AST | None:
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscript_shape
+            )
+        )
+
+        mask_exprs = []
+        dev_ptr_mask_exprs = []
+        # Generate mask
+        for dim, size in enumerate(multicast_shape):
+            if (
+                index := CompileEnvironment.current().get_block_id(size)
+            ) is not None and (mask_var := state.codegen.mask_var(index)) is not None:
+                expand = state.tile_strategy.expand_str(multicast_shape, dim)
+                dev_ptr_mask_exprs.append(f"({mask_var}{expand})")
+
+        if dev_ptr_mask_exprs:
+            dev_ptr_mask_expr = f"({'&'.join(dev_ptr_mask_exprs)})"
+            if len(dev_ptr_mask_exprs) < len(multicast_shape):
+                dev_ptr_mask_expr = f"tl.broadcast_to({dev_ptr_mask_expr}, {state.tile_strategy.shape_str(multicast_shape)})"
+            dev_ptr_mask_expr = f"({dev_ptr_mask_expr}){multicast_broadcast}"
+            mask_exprs.append(dev_ptr_mask_expr)
+
+        if indexing.has_mask():
+            mask_exprs.append(f"(tensor_mask){tensor_broadcast}")
+            return expr_from_string(
+                "&".join(mask_exprs), tensor_mask=indexing.mask_expr
+            )
+        if mask_exprs:
+            return expr_from_string("&".join(mask_exprs))
+        return None
+
+    @staticmethod
+    def codegen_load(
+        state: CodegenState,
+        multicast_tensor: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = multicast_tensor
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
+        extra = ", other=0"
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+            extra = ""
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.load((base.to(tl.pointer_type({dtype}))){multicast_broadcast} + (offset){tensor_broadcast}, mask{extra})",
+            base=dev_ptrs_ast,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
+
+    @staticmethod
+    def codegen_store(
+        state: CodegenState,
+        multicast_tensor: tuple[torch.Tensor, torch.Tensor],
+        dev_ptrs_ast: ast.AST,
+        subscript: list[object],
+        value: ast.AST,
+        extra_mask: ast.AST | None,
+    ) -> ast.AST:
+        tensor_like, dev_ptrs = multicast_tensor
+        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
+        subscripts_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
+        multicast_shape = [*dev_ptrs.size()]
+
+        mask_expr = MulticastIndexingStrategy.get_mask_expr(
+            state, indexing, multicast_shape, subscripts_shape
+        )
+        if mask_expr is None:
+            mask_expr = expr_from_string("None")
+
+        multicast_broadcast, tensor_broadcast = (
+            MulticastIndexingStrategy.get_broadcast_str(
+                multicast_shape, subscripts_shape
+            )
+        )
+
+        dtype = triton_type(tensor_like.dtype)
+        return expr_from_string(
+            f"tl.store(base.to(tl.pointer_type({dtype})){multicast_broadcast} + (offset){tensor_broadcast}, value, mask)",
+            base=dev_ptrs_ast,
+            value=value,
+            offset=indexing.index_expr,
+            mask=mask_expr,
+        )
+
+
 class SubscriptIndexing(NamedTuple):
    index_expr: ast.AST
    mask_expr: ast.AST
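
Note: the class docstring above states the broadcasting rule in words. The standalone sketch below (not part of the commit) mirrors the string logic of MulticastIndexingStrategy.get_broadcast_str to show the subscripts it emits; the comments sketch the rough shape of the pointer expression that codegen_load builds from them.

# Standalone sketch (not part of the commit): mirrors get_broadcast_str to show the
# subscript strings used to combine dev_ptrs with the per-tensor offset.
def get_broadcast_str_sketch(multicast_rank: int, subscript_rank: int) -> tuple[str, str]:
    multicast_keys = [":"] * multicast_rank + ["None"] * subscript_rank
    tensor_keys = ["None"] * multicast_rank + [":"] * subscript_rank
    return f"[{', '.join(multicast_keys)}]", f"[{', '.join(tensor_keys)}]"

# 1D dev_ptrs array, 2D per-tensor subscript:
#   dev_ptrs broadcast -> "[:, None, None]", offset broadcast -> "[None, :, :]"
# so the generated load is roughly
#   tl.load(dev_ptrs.to(tl.pointer_type(dtype))[:, None, None] + offset[None, :, :], mask, other=0)
print(get_broadcast_str_sketch(1, 2))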

helion/_compiler/roll_reduction.py

Lines changed: 30 additions & 0 deletions

@@ -6,6 +6,7 @@
 import torch
 from torch.fx import map_arg
 
+from ..language import _MEMORY_OPS
 from ..language._tracing_ops import _for_loop
 from ..language._tracing_ops import _get_symnode
 from ..language._tracing_ops import _host_tensor
@@ -277,6 +278,35 @@ def is_matmul_with_rdim(node: torch.fx.Node) -> bool:
 
        return any(is_matmul_with_rdim(node) for node in graph.nodes)
 
+    def has_multicast_tensor_with_rdim(self, graph: torch.fx.Graph) -> bool:
+        """Check if a graph contains multicast tensors with rdim inputs."""
+
+        def is_multicast_with_rdim(node: torch.fx.Node) -> bool:
+            """Check if a node is a multicast dev_ptr with rdim inputs."""
+            if node.op != "call_function":
+                return False
+
+            if node.target not in _MEMORY_OPS:
+                return False
+
+            host_tensor = node.args[0]
+
+            if not isinstance(host_tensor, tuple):
+                return False
+
+            # Check if multicast dims have rdim
+            if len(host_tensor) == 2:
+                assert isinstance(host_tensor[1], torch.fx.Node)
+                multicast = host_tensor[1].meta.get("val", None)
+                if isinstance(multicast, torch.Tensor):
+                    for size in multicast.size():
+                        block_idx = CompileEnvironment.current().get_block_id(size)
+                        if block_idx == self.rdim.block_id:
+                            return True
+            return False
+
+        return any(is_multicast_with_rdim(node) for node in graph.nodes)
+
    def process(self, graph: torch.fx.Graph) -> torch.fx.Graph:
        for node in graph.nodes:
            if self.should_go_in_inner_graph(node):
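
For reference, a minimal sketch (not from the commit) of the shape check at the heart of has_multicast_tensor_with_rdim: reduction rolling is disabled whenever a dimension of the dev_ptrs tensor maps to the reduction block. The block_id_of mapping is a hypothetical stand-in for CompileEnvironment.current().get_block_id.

# Standalone sketch: does any dev_ptrs dimension map to the reduction block?
def multicast_uses_rdim(
    multicast_sizes: list[int],
    block_id_of: dict[int, int],  # hypothetical size -> block id lookup
    rdim_block_id: int,
) -> bool:
    return any(block_id_of.get(size) == rdim_block_id for size in multicast_sizes)

# dev_ptrs of shape [8] where 8 is tracked as block 3, and block 3 is the rdim:
print(multicast_uses_rdim([8], {8: 3}, rdim_block_id=3))  # True -> cannot roll graphs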

helion/_compiler/type_propagation.py

Lines changed: 82 additions & 1 deletion

@@ -27,6 +27,7 @@
 from ..autotuner.config_spec import BlockSizeSpec
 from ..language._decorators import get_device_func_replacement
 from ..language._decorators import is_api_func
+from ..language.multicast_tensor import MulticastTensor
 from ..language.tile_proxy import Tile
 from ..language.tile_proxy import _CheckForIndexCalls
 from .ast_extension import ExtendedAST
@@ -1289,6 +1290,86 @@ def propagate_attribute(self, attr: str, origin: AttributeOrigin) -> TypeInfo:
        return self.element_types[attr]
 
 
+class MulticastTensorType(ClassType):
+    element_types: dict[str, TypeInfo]  # pyright: ignore[reportIncompatibleVariableOverride]
+
+    def proxy(self) -> MulticastTensor:  # pyright: ignore[reportIncompatibleMethodOverride]
+        with proxy_tensor.disable_proxy_modes_tracing():
+            fake_mode = torch._C._unset_dispatch_mode(  # pyright: ignore[reportAttributeAccessIssue]
+                torch._C._TorchDispatchModeKey.FAKE  # pyright: ignore[reportAttributeAccessIssue]
+            )
+            try:
+                assert isinstance(self.element_types["tensor_like"], TensorType)
+                assert isinstance(self.element_types["dev_ptrs"], TensorType)
+                return MulticastTensor(
+                    self.element_types["tensor_like"].proxy(),
+                    self.element_types["dev_ptrs"].proxy(),
+                )
+            finally:
+                assert fake_mode is not None
+                torch._C._set_dispatch_mode(fake_mode)  # pyright: ignore[reportAttributeAccessIssue]
+
+    def merge(self, other: TypeInfo) -> TypeInfo:
+        if isinstance(other, MulticastTensorType):
+            self_elements = self.element_types
+            other_elements = other.element_types
+            if set(self_elements.keys()) == set(other_elements.keys()):
+                return MulticastTensorType(
+                    origin=other.origin,
+                    element_types={
+                        key: self_elements[key].merge(other_elements[key])
+                        for key in self_elements
+                    },
+                )
+        return super().merge(other)
+
+    def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]:
+        tensor_like_type = self.element_types["tensor_like"]
+        assert isinstance(tensor_like_type, TensorType)
+        size_like = tensor_like_type._device_indexing_size(key)
+
+        dev_ptrs_type = self.element_types["dev_ptrs"]
+        assert isinstance(dev_ptrs_type, TensorType)
+        multicast_size = list(dev_ptrs_type.fake_value.size())
+
+        return multicast_size + size_like
+
+    def propagate_setitem(
+        self, key: TypeInfo, value: TypeInfo, origin: Origin
+    ) -> TypeInfo:
+        if origin.is_host():
+            warning(exc.TensorOperationInWrapper)
+        else:
+            lhs_shape = self._device_indexing_size(key)
+            lhs_rank = len(lhs_shape)
+            if isinstance(value, TensorType):
+                rhs_rank = value.fake_value.ndim
+                if lhs_rank != rhs_rank:
+                    raise exc.RankMismatch(
+                        lhs_rank,
+                        rhs_rank,
+                        f"LHS shape: {tuple(lhs_shape)}, RHS shape: {tuple(value.fake_value.shape)}",
+                    )
+            elif isinstance(value, (NumericType, LiteralType)):
+                # Allow scalar assignment to tensor (broadcasts to tensor shape)
+                pass
+            else:
+                raise exc.RequiresTensorInAssignment(value)
+        return self
+
+    def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
+        if origin.is_host():
+            warning(exc.TensorOperationInWrapper)
+
+        assert isinstance(self.element_types["tensor_like"], TensorType)
+        return TensorType(
+            origin,
+            self.element_types["tensor_like"]
+            .proxy()
+            .new_empty(self._device_indexing_size(key)),
+        )
+
+
 class SliceType(CollectionType):
    element_types: slice  # pyright: ignore[reportIncompatibleVariableOverride]
 
@@ -1614,7 +1695,7 @@ def _assign(self, lhs: ast.AST, rhs: TypeInfo) -> None:
        if isinstance(lhs, ast.Subscript):
            # TODO(jansel): test different types of subscript
            lhs_base_type = self.visit(lhs.value)
-            if isinstance(lhs_base_type, TensorType):
+            if isinstance(lhs_base_type, (TensorType, MulticastTensorType)):
                self.visit(lhs)  # need to populate shape info
                lhs_base_type = lhs_base_type.propagate_setitem(
                    self.visit(lhs.slice), rhs, self.origin()
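
A small sketch of the shape arithmetic behind MulticastTensorType._device_indexing_size and propagate_getitem: the dev_ptrs shape is prepended to the shape the subscript would produce on the template tensor, so a device-side read yields a stacked result. Sizes below are made up for illustration; this uses ordinary tensors rather than the compiler's fake-tensor machinery.

# Standalone sketch (hypothetical sizes): stacked shape = dev_ptrs.shape + subscript shape.
import torch

dev_ptrs = torch.empty(4, dtype=torch.int64)   # stand-in for a per-device pointer array
tensor_like = torch.empty(128, 64)             # template tensor supplying dtype/layout
subscript_shape = [32, 64]                     # shape a tile subscript would produce

stacked_shape = list(dev_ptrs.size()) + subscript_shape
result = tensor_like.new_empty(stacked_shape)  # mirrors propagate_getitem's new_empty call
print(result.shape)                            # torch.Size([4, 32, 64])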
