From fc3c3dfc54539da27715b7709ede6aeb5c2005bc Mon Sep 17 00:00:00 2001
From: joydddd
Date: Fri, 27 Jun 2025 15:53:09 -0700
Subject: [PATCH] Fix config flatten spec for tile.id

stack-info: PR: https://github.com/pytorch-labs/helion/pull/224, branch: joydddd/stack/7
---
 helion/language/tile_ops.py |  1 +
 test/test_misc.py           | 32 +++++++++++++++++++++++++++++---
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py
index cfdf84bf..7f5691d0 100644
--- a/helion/language/tile_ops.py
+++ b/helion/language/tile_ops.py
@@ -151,6 +151,7 @@ def tile_id(tile: Tile) -> int:
 
 @_decorators.register_fake(tile_id)
 def _(tile: torch.SymInt) -> torch.SymInt:
+    _disable_flatten_get_tile(tile)  # update config spec if needed
     assert isinstance(tile, torch.SymInt)
     return CompileEnvironment.current().cached_create_unbacked_symint(("tile_id", tile))
 
diff --git a/test/test_misc.py b/test/test_misc.py
index 0f3ed3b8..054a4520 100644
--- a/test/test_misc.py
+++ b/test/test_misc.py
@@ -274,7 +274,7 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
 
     def test_config_flatten_issue(self):
         @helion.kernel(use_default_config=True)
-        def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
+        def test_tile_begin(x: torch.Tensor) -> torch.Tensor:
             out = torch.zeros_like(x, dtype=torch.int32)
             for tile_m, tile_n in hl.tile(x.size()):
                 out[tile_m.begin, tile_n.begin] = 1
@@ -282,8 +282,34 @@ def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
 
         x = torch.randn(64, 64, device="cuda")
         config = helion.Config(block_sizes=[16, 16])
-        test_tile_id_atomic_add.bind((x,)).to_triton_code(config)
-        result = test_tile_id_atomic_add.bind((x,)).compile_config(config)(x)
+        test_tile_begin.bind((x,)).to_triton_code(config)
+        result = test_tile_begin.bind((x,)).compile_config(config)(x)
+        self.assertEqual(result.sum().item(), 16)
+
+        @helion.kernel(use_default_config=True)
+        def test_tile_end(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile_m, tile_n in hl.tile(x.size()):
+                out[tile_m.end, tile_n.end] = 1
+            return out
+
+        x = torch.randn(64, 64, device="cuda")
+        config = helion.Config(block_sizes=[16, 16])
+        test_tile_end.bind((x,)).to_triton_code(config)
+        result = test_tile_end.bind((x,)).compile_config(config)(x)
+        self.assertEqual(result.sum().item(), 12)
+
+        @helion.kernel(use_default_config=True)
+        def test_tile_id(x: torch.Tensor) -> torch.Tensor:
+            out = torch.zeros_like(x, dtype=torch.int32)
+            for tile_m, tile_n in hl.tile(x.size()):
+                out[tile_m.id, tile_n.id] = 1
+            return out
+
+        x = torch.randn(64, 64, device="cuda")
+        config = helion.Config(block_sizes=[16, 16])
+        test_tile_id.bind((x,)).to_triton_code(config)
+        result = test_tile_id.bind((x,)).compile_config(config)(x)
         self.assertEqual(result.sum().item(), 16)