Skip to content

Commit d507981

Browse files
authored
Generalize workaround for unbacked size hints (#159)
1 parent 16740ac commit d507981

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

helion/_compiler/compile_environment.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
255255

256256
def size_hint(self, n: int | torch.SymInt) -> int:
257257
if isinstance(n, torch.SymInt):
258+
expr = n._sympy_()
259+
if any(s.name.startswith("u") for s in expr.free_symbols):
260+
# If the size is a symbolic expression with unbacked symbols, then the shape environment
261+
# hint will be wrong since we assign a default value to unbacked symbols. Return a default hint.
262+
return 8192
258263
# pyre-ignore[6]
259264
return int(self.shape_env.size_hint(n._sympy_()))
260265
assert isinstance(n, int)

helion/_compiler/type_propagation.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -987,13 +987,7 @@ def _get_hint(numel: int | torch.SymInt | AutoSize | None) -> int:
987987
if numel is None or isinstance(numel, AutoSize):
988988
# For data-dependent sizes, use arbitrary hint of 8192
989989
return 8192
990-
991-
hint = CompileEnvironment.current().size_hint(numel)
992-
# If the hint is invalid (like 0), use a reasonable default
993-
# This can happen when other hints cancel out in expressions
994-
if hint <= 1:
995-
return 8192
996-
return hint
990+
return CompileEnvironment.current().size_hint(numel)
997991

998992

999993
class TileIndexType(TypeInfo):

0 commit comments

Comments
 (0)