diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index f673d9a7..26a7ea16 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -429,13 +429,6 @@ def _body(self, body: list[ast.stmt]) -> None: for stmt in body: self.visit(stmt) - def _to_proxy(self, node: ast.AST) -> object: - assert isinstance(node, ExtendedAST) - type_info = node._type_info - if not type_info.contains_tensor(): - return type_info.proxy() - return self.visit(node) - def visit_BinOp(self, node: ast.BinOp) -> object: return _eval_binary(node.op, self.visit(node.left), self.visit(node.right)) @@ -793,15 +786,15 @@ def visit_Call(self, node: ast.Call) -> object: for arg in node.args: if isinstance(arg, ast.Starred): # pyre-ignore[6] - args.extend(self._to_proxy(arg.value)) + args.extend(self.visit(arg.value)) else: - args.append(self._to_proxy(arg)) + args.append(self.visit(arg)) for kwarg in node.keywords: if kwarg.arg is None: # pyre-ignore[6] - kwargs.update(self._to_proxy(kwarg.value)) + kwargs.update(self.visit(kwarg.value)) else: - kwargs[kwarg.arg] = self._to_proxy(kwarg.value) + kwargs[kwarg.arg] = self.visit(kwarg.value) if isinstance( (func_type_info := node.func._type_info), # pyre-ignore[16] diff --git a/test/test_atomic_add.py b/test/test_atomic_add.py index 4b86d946..635a4176 100644 --- a/test/test_atomic_add.py +++ b/test/test_atomic_add.py @@ -47,6 +47,15 @@ def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor) -> torch.Ten return x +@helion.kernel() +def atomic_add_w_tile_attr(x: torch.Tensor) -> torch.Tensor: + """Test atomic_add where the index is a symbolic int""" + y = torch.zeros_like(x, device=x.device, dtype=torch.int32) + for tile in hl.tile(x.size(0)): + hl.atomic_add(y, [tile.begin], 1) + return y + + class TestAtomicOperations(TestCase): maxDiff = 16384 @@ -203,6 +212,18 @@ def bad_atomic_add_kernel(x: torch.Tensor, y: torch.Tensor): ) self.assertIn("Invalid memory semantic 'ERROR'", str(ctx.exception)) + def test_atomic_add_w_tile_attr(self): + """Test atomic_add where the index is a symbolic int""" + x = torch.randn(20, device=DEVICE) + code, result = code_and_output( + atomic_add_w_tile_attr, + (x,), + block_sizes=[2], + ) + + expected = torch.tensor([1, 0], device=DEVICE, dtype=torch.int32).repeat(10) + torch.testing.assert_close(result, expected) + if __name__ == "__main__": unittest.main()