Skip to content

Commit 3cab993

Browse files
committed
Fix visitCall in deviceIR. Always visit argument nodes
ghstack-source-id: fefa4f8 Pull Request resolved: #178 stack-info: PR: #180, branch: joydddd/stack/2
1 parent 46b617d commit 3cab993

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

helion/_compiler/device_ir.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -429,13 +429,6 @@ def _body(self, body: list[ast.stmt]) -> None:
429429
for stmt in body:
430430
self.visit(stmt)
431431

432-
def _to_proxy(self, node: ast.AST) -> object:
433-
assert isinstance(node, ExtendedAST)
434-
type_info = node._type_info
435-
if not type_info.contains_tensor():
436-
return type_info.proxy()
437-
return self.visit(node)
438-
439432
def visit_BinOp(self, node: ast.BinOp) -> object:
440433
return _eval_binary(node.op, self.visit(node.left), self.visit(node.right))
441434

@@ -793,15 +786,15 @@ def visit_Call(self, node: ast.Call) -> object:
793786
for arg in node.args:
794787
if isinstance(arg, ast.Starred):
795788
# pyre-ignore[6]
796-
args.extend(self._to_proxy(arg.value))
789+
args.extend(self.visit(arg.value))
797790
else:
798-
args.append(self._to_proxy(arg))
791+
args.append(self.visit(arg))
799792
for kwarg in node.keywords:
800793
if kwarg.arg is None:
801794
# pyre-ignore[6]
802-
kwargs.update(self._to_proxy(kwarg.value))
795+
kwargs.update(self.visit(kwarg.value))
803796
else:
804-
kwargs[kwarg.arg] = self._to_proxy(kwarg.value)
797+
kwargs[kwarg.arg] = self.visit(kwarg.value)
805798

806799
if isinstance(
807800
(func_type_info := node.func._type_info), # pyre-ignore[16]

test/test_atomic_add.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ def atomic_add_float_kernel(x: torch.Tensor, indices: torch.Tensor) -> torch.Ten
4747
return x
4848

4949

50+
@helion.kernel()
51+
def atomic_add_w_tile_attr(x: torch.Tensor) -> torch.Tensor:
52+
"""Test atomic_add where the index is a symbolic int"""
53+
y = torch.zeros_like(x, device=x.device, dtype=torch.int32)
54+
for tile in hl.tile(x.size(0)):
55+
hl.atomic_add(y, [tile.begin], 1)
56+
return y
57+
58+
5059
class TestAtomicOperations(TestCase):
5160
maxDiff = 16384
5261

@@ -203,6 +212,18 @@ def bad_atomic_add_kernel(x: torch.Tensor, y: torch.Tensor):
203212
)
204213
self.assertIn("Invalid memory semantic 'ERROR'", str(ctx.exception))
205214

215+
def test_atomic_add_w_tile_attr(self):
216+
"""Test atomic_add where the index is a symbolic int"""
217+
x = torch.randn(20, device=DEVICE)
218+
code, result = code_and_output(
219+
atomic_add_w_tile_attr,
220+
(x,),
221+
block_sizes=[2],
222+
)
223+
224+
expected = torch.tensor([1, 0], device=DEVICE, dtype=torch.int32).repeat(10)
225+
torch.testing.assert_close(result, expected)
226+
206227

207228
if __name__ == "__main__":
208229
unittest.main()

0 commit comments

Comments
 (0)