Commit 2d0cd8c: Add hl.reduce

stack-info: PR: #240, branch: jansel/stack/79
1 parent: 82a7849

File tree: 6 files changed, +1163 −0 lines

helion/_compiler/device_ir.py

Lines changed: 20 additions & 0 deletions

@@ -740,6 +740,9 @@ def visit_Assign(self, node: ast.Assign) -> None:
             for t, v in zip(target.elts, value, strict=True):
                 if isinstance(t, ast.Name):
                     self._assign(t, v)
+                elif isinstance(t, ast.Subscript):
+                    # Handle subscript targets in tuple unpacking (e.g., a[i], b[j] = tuple)
+                    self._assign_subscript(t, v)
                 else:
                     raise exc.InvalidAssignment
         return None

@@ -757,6 +760,23 @@ def visit_Assign(self, node: ast.Assign) -> None:
         target_origin = target.value._type_info.origin
         assert target_origin.is_host()
         val = self.visit(node.value)
+        self._assign_subscript(target, val)
+
+    def _assign_subscript(self, target: ast.Subscript, val: object) -> None:
+        """Helper method to assign a value to a subscript target."""
+        assert isinstance(target, ExtendedAST)
+        lhs_type = target._type_info
+
+        # Validate that we're assigning to a tensor subscript
+        from .type_propagation import TensorType
+
+        if not isinstance(lhs_type, TensorType):
+            raise exc.NonTensorSubscriptAssign(lhs_type, type(val))
+
+        assert isinstance(target.value, ExtendedAST)
+        target_origin = target.value._type_info.origin
+        assert target_origin.is_host()
+
         return hl.store(
             self.visit(target.value), self._subscript_slice_proxy(target.slice), val
         )
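
The new elif branch above means a device-loop statement such as a[i], b[j] = tuple no longer raises exc.InvalidAssignment: each subscript element of the unpacked tuple is routed through _assign_subscript, which type-checks the left-hand side and lowers to hl.store. A minimal sketch of the kind of kernel this accepts (the kernel itself is illustrative, not taken from the commit):

import torch
import helion
import helion.language as hl


# Illustrative only: both left-hand sides below are ast.Subscript nodes,
# so visit_Assign now sends each unpacked value through _assign_subscript.
@helion.kernel()
def add_sub(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    out_add = torch.empty_like(x)
    out_sub = torch.empty_like(x)
    for tile in hl.tile(x.size()):
        # Tuple unpacking into two tensor subscripts inside the device loop.
        out_add[tile], out_sub[tile] = x[tile] + y[tile], x[tile] - y[tile]
    return out_add, out_sub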

helion/_compiler/roll_reduction.py

Lines changed: 7 additions & 0 deletions

@@ -15,6 +15,7 @@
 from helion.language._tracing_ops import _host_tensor
 from helion.language._tracing_ops import _if
 from helion.language.memory_ops import store
+from helion.language.reduce_ops import _reduce

 if TYPE_CHECKING:
     from helion._compiler.compile_environment import BlockSizeInfo

@@ -69,6 +70,12 @@ def should_go_in_inner_graph(self, node: torch.fx.Node) -> bool:
             return False
         assert node.op == "call_function", f"Unsupported node type {node.op}"

+        if node.target is _reduce:
+            # TODO(jansel): support rolling user-defined reductions
+            raise NotImplementedError(
+                "hl._reduce operations are not compatible with reduction rolling"
+            )
+
         if node.target in (_for_loop, _if):
             if node.target is _for_loop:
                 graph_id, *_ = node.args
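
should_go_in_inner_graph walks the FX graph while deciding which nodes may be rolled into the inner reduction graph, and the new check simply refuses when it sees a _reduce call. A generic sketch of that kind of node-target scan over a torch.fx graph (not Helion's actual pass, just the underlying pattern):

import operator

import torch
from torch.fx import symbolic_trace


def contains_target(gm: torch.fx.GraphModule, target: object) -> bool:
    # Same shape of check as above: look for call_function nodes whose
    # target is a specific op before deciding how to transform the graph.
    return any(
        node.op == "call_function" and node.target is target
        for node in gm.graph.nodes
    )


def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + operator.mul(x, y)


gm = symbolic_trace(f)
print(contains_target(gm, operator.mul))  # True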

helion/language/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .reduce_ops import reduce as reduce
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
 from .scan_ops import cumsum as cumsum
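
With reduce re-exported from reduce_ops, kernels can call it as hl.reduce. A sketch of how a user-defined reduction might look, assuming hl.reduce follows the same (combine_fn, tensor, dim) calling convention as the hl.associative_scan import that sits next to it; the exact signature should be taken from reduce_ops.py and the tests added in this commit:

import torch
import helion
import helion.language as hl


def mul_combine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # User-defined combine function: element-wise product.
    return a * b


# Sketch under the assumed (combine_fn, tensor, dim) convention.
@helion.kernel()
def row_products(x: torch.Tensor) -> torch.Tensor:
    m, _n = x.size()
    out = torch.empty([m], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        # Reduce each row of the tile with the user-defined combine function.
        out[tile_m] = hl.reduce(mul_combine, x[tile_m, :], dim=1)
    return out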
