Skip to content

Commit 8d6fa33

Browse files
committed
Add hl.associative_scan
stack-info: PR: #239, branch: jansel/stack/78
1 parent 5c8e35b commit 8d6fa33

File tree

11 files changed

+1630
-30
lines changed

11 files changed

+1630
-30
lines changed

helion/_compiler/compile_environment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,15 @@ def to_fake(self, obj: object, origin: Origin) -> object:
222222
),
223223
):
224224
return obj
225-
if isinstance(obj, types.FunctionType):
225+
# Handle functions and Kernel objects
226+
from ..runtime.kernel import Kernel
227+
228+
if isinstance(obj, (types.FunctionType, Kernel)):
229+
from .helper_function import extract_helper_function
226230
from .lift_closures import lift_closures
227231

228-
return lift_closures(obj, origin)
232+
fn = extract_helper_function(obj)
233+
return lift_closures(fn, origin)
229234
if isinstance(obj, ConstExpr):
230235
return obj.value
231236
if isinstance(obj, list):

helion/_compiler/device_function.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
if TYPE_CHECKING:
3939
from ..runtime.config import Config
40+
from .device_ir import HelperFunctionGraphInfo
4041
from .generate_ast import GenerateAST
4142
from .program_id import ProgramIDs
4243

@@ -185,6 +186,10 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
185186
self.block_size_var_cache: dict[tuple[int, ...], str] = {}
186187
self.expr_to_var_info: dict[sympy.Expr, VarInfo] = {}
187188

189+
from .helper_function import HelperFunctionManager
190+
191+
self.helper_manager = HelperFunctionManager()
192+
188193
from .indexing_strategy import IndexingStrategy
189194
from .tile_dispatch import TileStrategyDispatch
190195

@@ -488,6 +493,16 @@ def dead_code_elimination(self) -> None:
488493
if v.name in args_to_remove:
489494
del cache[k]
490495

496+
def register_helper_function(
497+
self, helper_graph_info: HelperFunctionGraphInfo
498+
) -> None:
499+
"""Register a helper function to be generated at global scope."""
500+
self.helper_manager.register_helper_function(helper_graph_info)
501+
502+
def codegen_helper_functions(self) -> list[ast.stmt]:
503+
"""Generate helper function definitions at global scope."""
504+
return self.helper_manager.codegen_helper_functions()
505+
491506
def __enter__(self) -> None:
492507
try:
493508
tls.functions.append(self)

helion/_compiler/device_ir.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import ast
44
import builtins
5-
from collections.abc import Callable
65
import contextlib
76
import dataclasses
87
import functools
@@ -339,7 +338,11 @@ def build_rolled_reductions(self) -> None:
339338
for graph_id, graph_info in enumerate([*self.graphs]):
340339
assert graph_id == graph_info.graph_id
341340
roller = ReductionRoller(self, rdim, graph_to_info)
342-
new_graph = roller.process(graph_info.graph)
341+
try:
342+
new_graph = roller.process(graph_info.graph)
343+
except NotImplementedError:
344+
first = False
345+
break
343346
new_graph_id = self.add_graph(
344347
new_graph, type(graph_info), **graph_info.kwargs()
345348
)
@@ -898,6 +901,26 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
898901
return device_ir
899902

900903

904+
@dataclasses.dataclass
905+
class HelperFunctionGraphInfo(NodeArgsGraphInfo):
906+
"""Graph info for helper functions in higher-order operations like associative_scan."""
907+
908+
_param_names: list[str] = dataclasses.field(default_factory=list)
909+
910+
@property
911+
def name(self) -> str:
912+
return f"helper_function_{self.graph_id}"
913+
914+
def find_input_nodes(self) -> list[torch.fx.Node]:
915+
"""Find all placeholder nodes (inputs) in the graph."""
916+
return self.graph.find_nodes(op="placeholder")
917+
918+
def codegen(self, state: CodegenState) -> list[object]:
919+
from .helper_function import codegen_helper_function_graph_info
920+
921+
return codegen_helper_function_graph_info(self, state)
922+
923+
901924
def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
902925
"""
903926
Remove unnecessary tile_index nodes from the graph.

helion/_compiler/generate_ast.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .ast_extension import statement_from_string
1818
from .compile_environment import CompileEnvironment
1919
from .device_function import DeviceFunction
20+
from .helper_function import CodegenInterface
2021
from .inductor_lowering import CodegenState
2122
from .inductor_lowering import codegen_call_with_graph
2223
from .program_id import ForEachProgramID
@@ -32,19 +33,25 @@
3233
from .type_propagation import TensorType
3334

3435

35-
class GenerateAST(NodeVisitor):
36+
class GenerateAST(NodeVisitor, CodegenInterface):
3637
def __init__(self, func: HostFunction, config: Config) -> None:
37-
super().__init__()
38+
# Initialize NodeVisitor first
39+
NodeVisitor.__init__(self)
40+
41+
# Initialize our attributes
3842
self.host_function = func
3943
self.host_statements: list[ast.AST] = []
4044
self.statements_stack: list[list[ast.AST]] = [self.host_statements]
4145
self.on_device = False
42-
self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
4346
self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
4447
collections.defaultdict(list)
4548
)
4649
self.next_else_block: list[ast.AST] | None = None
4750

51+
# Now create device function and initialize CodegenInterface
52+
self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
53+
CodegenInterface.__init__(self, self.device_function)
54+
4855
def offset_var(self, block_idx: int) -> str:
4956
return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx)
5057

@@ -63,9 +70,6 @@ def add_statement(self, stmt: ast.AST | str | None) -> None:
6370
stmt = statement_from_string(stmt)
6471
self.statements_stack[-1].append(stmt)
6572

66-
def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
67-
return self.device_function.unique_name(prefix, dce=dce)
68-
6973
def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
7074
if isinstance(expr, ast.Name):
7175
return expr
@@ -413,6 +417,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
413417
result = ast.Module(
414418
[
415419
*func.codegen_imports(),
420+
*codegen.device_function.codegen_helper_functions(),
416421
*kernel_def,
417422
host_def,
418423
precompile_def,

0 commit comments

Comments
 (0)