Skip to content

Commit a71811e

Browse files
committed
Add hl.associative_scan
stack-info: PR: #239, branch: jansel/stack/78
1 parent a6d5031 commit a71811e

File tree

11 files changed

+1554
-30
lines changed

11 files changed

+1554
-30
lines changed

helion/_compiler/compile_environment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,15 @@ def to_fake(self, obj: object, origin: Origin) -> object:
222222
),
223223
):
224224
return obj
225-
if isinstance(obj, types.FunctionType):
225+
# Handle functions and Kernel objects
226+
from ..runtime.kernel import Kernel
227+
228+
if isinstance(obj, (types.FunctionType, Kernel)):
229+
from .helper_function import extract_helper_function
226230
from .lift_closures import lift_closures
227231

228-
return lift_closures(obj, origin)
232+
fn = extract_helper_function(obj)
233+
return lift_closures(fn, origin)
229234
if isinstance(obj, ConstExpr):
230235
return obj.value
231236
if isinstance(obj, list):

helion/_compiler/device_function.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
if TYPE_CHECKING:
3939
from ..runtime.config import Config
40+
from .device_ir import HelperFunctionGraphInfo
4041
from .generate_ast import GenerateAST
4142
from .program_id import ProgramIDs
4243

@@ -185,6 +186,10 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
185186
self.block_size_var_cache: dict[tuple[int, ...], str] = {}
186187
self.expr_to_var_info: dict[sympy.Expr, VarInfo] = {}
187188

189+
from .helper_function import HelperFunctionManager
190+
191+
self.helper_manager = HelperFunctionManager()
192+
188193
from .indexing_strategy import IndexingStrategy
189194
from .tile_dispatch import TileStrategyDispatch
190195

@@ -488,6 +493,16 @@ def dead_code_elimination(self) -> None:
488493
if v.name in args_to_remove:
489494
del cache[k]
490495

496+
def register_helper_function(
497+
self, helper_graph_info: HelperFunctionGraphInfo
498+
) -> None:
499+
"""Register a helper function to be generated at global scope."""
500+
self.helper_manager.register_helper_function(helper_graph_info)
501+
502+
def codegen_helper_functions(self) -> list[ast.stmt]:
503+
"""Generate helper function definitions at global scope."""
504+
return self.helper_manager.codegen_helper_functions()
505+
491506
def __enter__(self) -> None:
492507
try:
493508
tls.functions.append(self)

helion/_compiler/device_ir.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import ast
44
import builtins
5-
from collections.abc import Callable
65
import contextlib
76
import dataclasses
87
import functools
@@ -339,7 +338,11 @@ def build_rolled_reductions(self) -> None:
339338
for graph_id, graph_info in enumerate([*self.graphs]):
340339
assert graph_id == graph_info.graph_id
341340
roller = ReductionRoller(self, rdim, graph_to_info)
342-
new_graph = roller.process(graph_info.graph)
341+
try:
342+
new_graph = roller.process(graph_info.graph)
343+
except NotImplementedError:
344+
first = False
345+
break
343346
new_graph_id = self.add_graph(
344347
new_graph, type(graph_info), **graph_info.kwargs()
345348
)
@@ -898,6 +901,26 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
898901
return device_ir
899902

900903

904+
@dataclasses.dataclass
905+
class HelperFunctionGraphInfo(NodeArgsGraphInfo):
906+
"""Graph info for helper functions in higher-order operations like associative_scan."""
907+
908+
_param_names: list[str] = dataclasses.field(default_factory=list)
909+
910+
@property
911+
def name(self) -> str:
912+
return f"helper_function_{self.graph_id}"
913+
914+
def find_input_nodes(self) -> list[torch.fx.Node]:
915+
"""Find all placeholder nodes (inputs) in the graph."""
916+
return self.graph.find_nodes(op="placeholder")
917+
918+
def codegen(self, state: CodegenState) -> list[object]:
919+
from .helper_function import codegen_helper_function_graph_info
920+
921+
return codegen_helper_function_graph_info(self, state)
922+
923+
901924
def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
902925
"""
903926
Remove unnecessary tile_index nodes from the graph.

helion/_compiler/generate_ast.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .ast_extension import statement_from_string
1818
from .compile_environment import CompileEnvironment
1919
from .device_function import DeviceFunction
20+
from .helper_function import CodegenInterface
2021
from .inductor_lowering import CodegenState
2122
from .inductor_lowering import codegen_call_with_graph
2223
from .program_id import ForEachProgramID
@@ -32,19 +33,25 @@
3233
from .type_propagation import TensorType
3334

3435

35-
class GenerateAST(NodeVisitor):
36+
class GenerateAST(NodeVisitor, CodegenInterface):
3637
def __init__(self, func: HostFunction, config: Config) -> None:
37-
super().__init__()
38+
# Initialize NodeVisitor first
39+
NodeVisitor.__init__(self)
40+
41+
# Initialize our attributes
3842
self.host_function = func
3943
self.host_statements: list[ast.AST] = []
4044
self.statements_stack: list[list[ast.AST]] = [self.host_statements]
4145
self.on_device = False
42-
self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
4346
self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
4447
collections.defaultdict(list)
4548
)
4649
self.next_else_block: list[ast.AST] | None = None
4750

51+
# Now create device function and initialize CodegenInterface
52+
self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
53+
CodegenInterface.__init__(self, self.device_function)
54+
4855
def offset_var(self, block_idx: int) -> str:
4956
return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx)
5057

@@ -63,9 +70,6 @@ def add_statement(self, stmt: ast.AST | str | None) -> None:
6370
stmt = statement_from_string(stmt)
6471
self.statements_stack[-1].append(stmt)
6572

66-
def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
67-
return self.device_function.unique_name(prefix, dce=dce)
68-
6973
def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
7074
if isinstance(expr, ast.Name):
7175
return expr
@@ -413,6 +417,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
413417
result = ast.Module(
414418
[
415419
*func.codegen_imports(),
420+
*codegen.device_function.codegen_helper_functions(),
416421
*kernel_def,
417422
host_def,
418423
precompile_def,

helion/_compiler/helper_function.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC
4+
from abc import abstractmethod
5+
import ast
6+
from typing import TYPE_CHECKING
7+
from typing import cast
8+
9+
from .ast_extension import create
10+
from .ast_extension import create_arg
11+
from .ast_extension import create_arguments
12+
from .ast_extension import expr_from_string
13+
from .ast_extension import statement_from_string
14+
15+
if TYPE_CHECKING:
16+
import types
17+
18+
from .device_function import DeviceFunction
19+
from .device_ir import HelperFunctionGraphInfo
20+
21+
22+
class CodegenInterface(ABC):
23+
"""Abstract base class for codegen interfaces used by GraphInterpreter."""
24+
25+
def __init__(self, device_function: DeviceFunction) -> None:
26+
self.device_function = device_function
27+
28+
@abstractmethod
29+
def add_statement(self, stmt: ast.AST | str | None) -> None:
30+
"""Add a statement to the generated code."""
31+
32+
def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
33+
"""Generate a temporary variable name."""
34+
return self.device_function.unique_name(prefix, dce=dce)
35+
36+
def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
37+
"""Lift an expression to a temporary variable if needed."""
38+
if isinstance(expr, ast.Name):
39+
return expr
40+
varname = self.tmpvar(dce=dce, prefix=prefix)
41+
self.add_statement(statement_from_string(f"{varname} = expr", expr=expr))
42+
return create(ast.Name, id=varname, ctx=ast.Load())
43+
44+
45+
def extract_helper_function(helper_fn: object) -> types.FunctionType:
46+
"""Extract the actual function from a Kernel object or return as-is.
47+
48+
This utility function centralizes the logic for handling both regular functions
49+
and Kernel objects that wrap functions.
50+
"""
51+
from ..runtime.kernel import Kernel
52+
53+
# pyre-ignore[16]: We check isinstance before accessing .fn
54+
return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn
55+
56+
57+
class HelperCodegen(CodegenInterface):
58+
"""Codegen wrapper for helper function generation."""
59+
60+
def __init__(self, device_function: DeviceFunction) -> None:
61+
super().__init__(device_function)
62+
63+
def add_statement(self, stmt: ast.AST | str | None) -> None:
64+
if stmt is not None:
65+
if isinstance(stmt, str):
66+
stmt = statement_from_string(stmt)
67+
self.device_function.body.append(stmt)
68+
69+
70+
class HelperFunctionManager:
71+
"""Manages helper function registration and code generation."""
72+
73+
def __init__(self) -> None:
74+
self.helper_functions: dict[str, HelperFunctionGraphInfo] = {}
75+
76+
def register_helper_function(
77+
self, helper_graph_info: HelperFunctionGraphInfo
78+
) -> None:
79+
"""Register a helper function to be generated at global scope."""
80+
self.helper_functions[helper_graph_info.name] = helper_graph_info
81+
82+
def codegen_helper_functions(self) -> list[ast.stmt]:
83+
"""Generate helper function definitions at global scope."""
84+
helper_defs = []
85+
for helper_graph_info in self.helper_functions.values():
86+
# Determine the number of parameters from the graph
87+
input_nodes = helper_graph_info.find_input_nodes()
88+
89+
# Generate argument list with consistent names
90+
args = []
91+
param_names = []
92+
for i in range(len(input_nodes)):
93+
arg_name = f"param_{i}"
94+
args.append(create_arg(arg_name))
95+
param_names.append(arg_name)
96+
97+
# Store parameter names for use in body generation
98+
helper_graph_info._param_names = param_names
99+
100+
# Process the FX graph to generate the correct helper function body
101+
func_body = self._codegen_helper_function_body(helper_graph_info)
102+
103+
# Generate the function structure with @triton.jit decorator
104+
func_def = create(
105+
ast.FunctionDef,
106+
name=helper_graph_info.name,
107+
args=create_arguments(args),
108+
body=func_body,
109+
decorator_list=[expr_from_string("triton.jit")],
110+
type_params=[],
111+
)
112+
113+
helper_defs.append(func_def)
114+
115+
return helper_defs
116+
117+
def _codegen_helper_function_body(
118+
self, helper_graph_info: HelperFunctionGraphInfo
119+
) -> list[ast.stmt]:
120+
"""Generate the body of a helper function by processing its FX graph."""
121+
temp_device_function = self._create_temp_device_function(helper_graph_info)
122+
param_args = self._create_parameter_args(helper_graph_info)
123+
124+
with temp_device_function:
125+
results = self._process_helper_graph(
126+
helper_graph_info, temp_device_function, param_args
127+
)
128+
statements = temp_device_function.body.copy()
129+
self._ensure_return_statement(statements, results, helper_graph_info.name)
130+
131+
return cast("list[ast.stmt]", statements)
132+
133+
def _create_temp_device_function(
134+
self, helper_graph_info: HelperFunctionGraphInfo
135+
) -> DeviceFunction:
136+
"""Create a temporary DeviceFunction for helper function generation."""
137+
# Import here to avoid circular imports
138+
from .device_function import DeviceFunction
139+
140+
current = DeviceFunction.current()
141+
142+
return DeviceFunction(
143+
name=f"temp_{helper_graph_info.name}",
144+
config=current.config,
145+
codegen=current.codegen,
146+
)
147+
148+
def _create_parameter_args(
149+
self, helper_graph_info: HelperFunctionGraphInfo
150+
) -> list[ast.AST]:
151+
"""Create parameter AST nodes for the helper function."""
152+
param_names = helper_graph_info._param_names
153+
return [expr_from_string(param_name) for param_name in param_names]
154+
155+
def _process_helper_graph(
156+
self,
157+
helper_graph_info: HelperFunctionGraphInfo,
158+
temp_device_function: DeviceFunction,
159+
param_args: list[ast.AST],
160+
) -> object:
161+
"""Process the graph using the existing interpreter infrastructure."""
162+
from .inductor_lowering import GraphInterpreter
163+
164+
helper_codegen = HelperCodegen(temp_device_function)
165+
interpreter = GraphInterpreter(helper_graph_info.graph, helper_codegen)
166+
return interpreter.run(*param_args)
167+
168+
def _ensure_return_statement(
169+
self, statements: list[ast.AST], results: object, function_name: str
170+
) -> None:
171+
"""Ensure the function body has a proper return statement."""
172+
if statements and isinstance(statements[-1], ast.Return):
173+
return
174+
175+
if isinstance(results, ast.AST):
176+
statements.append(create(ast.Return, value=results))
177+
elif isinstance(results, (list, tuple)) and all(
178+
isinstance(r, ast.AST) for r in results
179+
):
180+
tuple_ast = create(ast.Tuple, elts=list(results), ctx=ast.Load())
181+
statements.append(create(ast.Return, value=tuple_ast))
182+
else:
183+
raise RuntimeError(
184+
f"Helper function {function_name} produced invalid result: {type(results)} {results}"
185+
)
186+
187+
188+
def codegen_helper_function_graph_info(
189+
helper_graph_info: HelperFunctionGraphInfo, state: object
190+
) -> list[object]:
191+
"""Generate code for HelperFunctionGraphInfo objects."""
192+
from .inductor_lowering import CodegenState
193+
from .inductor_lowering import codegen_call_with_graph
194+
195+
if not isinstance(state, CodegenState):
196+
raise TypeError(f"Expected CodegenState, got {type(state)}")
197+
198+
# For helper functions, we need to inline the function body
199+
# The helper function takes variable arguments and returns their combination
200+
201+
# Generate temporary variable names for the helper function arguments
202+
# Use the graph's input nodes to determine the number of parameters
203+
input_nodes = helper_graph_info.find_input_nodes()
204+
args: list[ast.AST] = []
205+
206+
for i in range(len(input_nodes)):
207+
var_name = state.codegen.tmpvar(prefix=f"helper_arg_{i}")
208+
args.append(create(ast.Name, id=var_name, ctx=ast.Load()))
209+
210+
# Generate the helper function call
211+
return codegen_call_with_graph(state.codegen, helper_graph_info.graph, args)

0 commit comments

Comments
 (0)