From 78b4d3a220ebf6c0407ae4350431f9aeea66da0d Mon Sep 17 00:00:00 2001 From: joydddd Date: Mon, 16 Jun 2025 15:13:08 -0700 Subject: [PATCH] Add lowering for Constant assignment stack-info: PR: https://github.com/pytorch-labs/helion/pull/187, branch: joydddd/stack/3 --- helion/_compiler/inductor_lowering.py | 2 ++ test/test_indexing.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 780ec810..51ec9397 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -988,6 +988,8 @@ def proxy_arg(self, i: int) -> object: def ast_arg(self, i: int) -> ast.AST: rv = self.ast_args[i] + if isinstance(rv, int | float | bool): + rv = ast.Constant(value=rv) assert isinstance(rv, ast.AST), "TODO: convert nested/defaults" return rv diff --git a/test/test_indexing.py b/test/test_indexing.py index fbb07741..e1f385df 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -279,6 +279,21 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor: expected = torch.full_like(x, 1, dtype=torch.int32) torch.testing.assert_close(result, expected) + def test_assign_int(self): + @helion.kernel + def fn(x: torch.Tensor) -> torch.Tensor: + for tile in hl.tile(x.size(0)): + x[tile] = 1 + return x + + x = torch.zeros([200], device=DEVICE) + expected = torch.ones_like(x) + code, result = code_and_output( + fn, + (x,), + ) + torch.testing.assert_close(result, expected) + def test_atomic_add_symint(self): @helion.kernel(config={"block_size": 32}) def fn(x: torch.Tensor) -> torch.Tensor: