
Commit f2a137b

Fix block size variable handling and atomic operations with symints (#177)

1 parent 808e1d6

4 files changed: 29 additions, 27 deletions

helion/language/_tracing_ops.py

Lines changed: 4 additions & 3 deletions
@@ -39,9 +39,10 @@ def _(state: CodegenState) -> ast.AST:
     val = state.fx_node.meta["val"]
     assert isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)), val
     if (block_idx := CompileEnvironment.current().get_block_id(val)) is not None:
-        if state.device_function.block_size_var(block_idx) is None:
-            # this should be unused
-            return expr_from_string("block_size_var_optimized_away")
+        block_size_var = state.device_function.block_size_var(block_idx)
+        if block_size_var is None:
+            return expr_from_string("1")
+        return expr_from_string(block_size_var)
     return state.codegen.lift(
         expr_from_string(state.sympy_expr(val._sympy_())),
         dce=True,
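
For a SymInt that maps to a tile's block index, the codegen previously emitted the placeholder name "block_size_var_optimized_away" whenever no block-size variable existed, which produced an unresolvable identifier if the value was actually used. Per the comment removed in tile_ops.py below, the variable is only omitted for grid tiles with block_size=1, so emitting the literal "1" is always a correct lowering. A minimal sketch of the new fallback, using a hypothetical helper name rather than Helion's API:

    def lower_block_size(block_size_var: str | None) -> str:
        # A missing variable only happens when the block size is a
        # compile-time constant 1, so the literal is a valid lowering.
        if block_size_var is None:
            return "1"
        return block_size_var

    assert lower_block_size(None) == "1"
    assert lower_block_size("_BLOCK_SIZE_0") == "_BLOCK_SIZE_0"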

helion/language/memory_ops.py

Lines changed: 7 additions & 12 deletions
@@ -4,9 +4,11 @@
 from typing import TYPE_CHECKING

 import torch
+from torch._inductor.codegen.simd import constant_repr
 from torch.fx import has_side_effect

 from .. import exc
+from .._compiler.ast_extension import expr_from_string
 from .._compiler.indexing_strategy import SubscriptIndexing
 from . import _decorators

@@ -170,26 +172,19 @@ def _(

 @_decorators.codegen(atomic_add)
 def _(state: CodegenState) -> ast.AST:
-    import ast
-
-    from .._compiler.ast_extension import expr_from_string
-
     target = state.proxy_arg(0)
     index = state.proxy_arg(1)
-    value = state.proxy_arg(2)
-    sem = expr_from_string(f"'{state.proxy_arg(3)}'")
+    sem = expr_from_string(repr(state.proxy_arg(3)))

     assert isinstance(target, torch.Tensor)
-    assert isinstance(index, (list))
+    assert isinstance(index, list)

     indices = SubscriptIndexing.create(state, target, index)
     name = state.device_function.tensor_arg(target).name

-    value_expr = (
-        state.ast_args[2]
-        if isinstance(value, torch.Tensor)
-        else ast.Constant(value=value)
-    )
+    value_expr = state.ast_args[2]
+    if isinstance(value_expr, (int, float, bool)):
+        value_expr = expr_from_string(constant_repr(value_expr))
     assert isinstance(value_expr, ast.AST)
     return expr_from_string(
         f"tl.atomic_add({name} + offset, value, mask=mask, sem=sem)",

helion/language/tile_ops.py

Lines changed: 3 additions & 12 deletions
@@ -124,17 +124,8 @@ def tile_block_size(tile: Tile) -> int:

 @_decorators.register_fake(tile_block_size)
 def _(tile: torch.SymInt) -> torch.SymInt:
-    assert isinstance(tile, torch.SymInt)
-    return CompileEnvironment.current().create_unbacked_symint()
-
-
-@_decorators.codegen(tile_block_size)
-def _(state: CodegenState) -> ast.AST:
-    index = _get_tile_index(state)
-    block_size_var = state.device_function.block_size_var(index)
+    return tile

-    if block_size_var is not None:
-        return expr_from_string(block_size_var)

-    # Final fallback for grid tiles with block_size=1
-    return expr_from_string("1")
+# since we return tile above, no codegen is needed for this function.
+# codegen is handled in _get_symnode()
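
Design note: registering an identity fake means tile.block_size traces through as the tile's existing SymInt, so the dedicated codegen hook (with its duplicate block_size=1 fallback) could be deleted; lowering now happens once, in the _get_symnode codegen patched above. A hedged usage sketch (an assumed kernel, not part of this commit) of block_size joining ordinary traced arithmetic:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel(config={"block_size": 16})
    def scale_by_block_size(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        for tile in hl.tile(x.size(0)):
            # tile.block_size is a SymInt; it lowers to the block-size
            # variable, or to the literal 1 if that variable was optimized away.
            out[tile] = x[tile] * tile.block_size
        return out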

test/test_indexing.py

Lines changed: 15 additions & 0 deletions
@@ -279,6 +279,21 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
         expected = torch.full_like(x, 1, dtype=torch.int32)
         torch.testing.assert_close(result, expected)

+    def test_atomic_add_symint(self):
+        @helion.kernel(config={"block_size": 32})
+        def fn(x: torch.Tensor) -> torch.Tensor:
+            for tile in hl.tile(x.size(0)):
+                hl.atomic_add(x, [tile], tile.block_size + 1)
+            return x
+
+        x = torch.zeros([200], device=DEVICE)
+        expected = x + 33
+        code, result = code_and_output(
+            fn,
+            (x,),
+        )
+        torch.testing.assert_close(result, expected)
+

 if __name__ == "__main__":
     unittest.main()
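
The new test exercises the end-to-end path: with the block size configured at 32, tile.block_size + 1 is a SymInt expression evaluating to 33, and each of the 200 elements receives exactly one atomic add (the partial last tile still reports the configured block size, not its extent), so the result equals x + 33.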
