Skip to content

Commit 0377217

Browse files
authored
Fix config flatten spec for tile.id (#224)
1 parent 0ba5311 commit 0377217

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

helion/language/tile_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def tile_id(tile: Tile) -> int:
151151

152152
@_decorators.register_fake(tile_id)
153153
def _(tile: torch.SymInt) -> torch.SymInt:
154+
_disable_flatten_get_tile(tile) # update config spec if needed
154155
assert isinstance(tile, torch.SymInt)
155156
return CompileEnvironment.current().cached_create_unbacked_symint(("tile_id", tile))
156157

test/test_misc.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,16 +274,42 @@ def _kernel_make_precompiler(a_list, b_dict, b_tuple, c_named_tuple, d_dataclass
274274

275275
def test_config_flatten_issue(self):
276276
@helion.kernel(use_default_config=True)
277-
def test_tile_id_atomic_add(x: torch.Tensor) -> torch.Tensor:
277+
def test_tile_begin(x: torch.Tensor) -> torch.Tensor:
278278
out = torch.zeros_like(x, dtype=torch.int32)
279279
for tile_m, tile_n in hl.tile(x.size()):
280280
out[tile_m.begin, tile_n.begin] = 1
281281
return out
282282

283283
x = torch.randn(64, 64, device="cuda")
284284
config = helion.Config(block_sizes=[16, 16])
285-
test_tile_id_atomic_add.bind((x,)).to_triton_code(config)
286-
result = test_tile_id_atomic_add.bind((x,)).compile_config(config)(x)
285+
test_tile_begin.bind((x,)).to_triton_code(config)
286+
result = test_tile_begin.bind((x,)).compile_config(config)(x)
287+
self.assertEqual(result.sum().item(), 16)
288+
289+
@helion.kernel(use_default_config=True)
290+
def test_tile_end(x: torch.Tensor) -> torch.Tensor:
291+
out = torch.zeros_like(x, dtype=torch.int32)
292+
for tile_m, tile_n in hl.tile(x.size()):
293+
out[tile_m.end, tile_n.end] = 1
294+
return out
295+
296+
x = torch.randn(64, 64, device="cuda")
297+
config = helion.Config(block_sizes=[16, 16])
298+
test_tile_end.bind((x,)).to_triton_code(config)
299+
result = test_tile_end.bind((x,)).compile_config(config)(x)
300+
self.assertEqual(result.sum().item(), 12)
301+
302+
@helion.kernel(use_default_config=True)
303+
def test_tile_id(x: torch.Tensor) -> torch.Tensor:
304+
out = torch.zeros_like(x, dtype=torch.int32)
305+
for tile_m, tile_n in hl.tile(x.size()):
306+
out[tile_m.id, tile_n.id] = 1
307+
return out
308+
309+
x = torch.randn(64, 64, device="cuda")
310+
config = helion.Config(block_sizes=[16, 16])
311+
test_tile_id.bind((x,)).to_triton_code(config)
312+
result = test_tile_id.bind((x,)).compile_config(config)(x)
287313
self.assertEqual(result.sum().item(), 16)
288314

289315

0 commit comments

Comments
 (0)