Skip to content

Commit bb7bfa1

Browse files
committed
Add hl.associative_scan
stack-info: PR: #239, branch: jansel/stack/78
1 parent a6d5031 commit bb7bfa1

12 files changed

+892
-14
lines changed

helion/_compiler/compile_environment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,15 @@ def to_fake(self, obj: object, origin: Origin) -> object:
222222
),
223223
):
224224
return obj
225-
if isinstance(obj, types.FunctionType):
225+
# Handle functions and Kernel objects
226+
from ..runtime.kernel import Kernel
227+
228+
if isinstance(obj, (types.FunctionType, Kernel)):
229+
from .helper_function import extract_helper_function
226230
from .lift_closures import lift_closures
227231

228-
return lift_closures(obj, origin)
232+
fn = extract_helper_function(obj)
233+
return lift_closures(fn, origin)
229234
if isinstance(obj, ConstExpr):
230235
return obj.value
231236
if isinstance(obj, list):

helion/_compiler/device_function.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
if TYPE_CHECKING:
3939
from ..runtime.config import Config
40+
from .device_ir import HelperFunctionGraphInfo
4041
from .generate_ast import GenerateAST
4142
from .program_id import ProgramIDs
4243

@@ -185,6 +186,8 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
185186
self.block_size_var_cache: dict[tuple[int, ...], str] = {}
186187
self.expr_to_var_info: dict[sympy.Expr, VarInfo] = {}
187188

189+
self.helper_functions: dict[str, HelperFunctionGraphInfo] = {}
190+
188191
from .indexing_strategy import IndexingStrategy
189192
from .tile_dispatch import TileStrategyDispatch
190193

@@ -488,6 +491,90 @@ def dead_code_elimination(self) -> None:
488491
if v.name in args_to_remove:
489492
del cache[k]
490493

494+
def register_helper_function(
495+
self, helper_graph_info: HelperFunctionGraphInfo
496+
) -> None:
497+
"""Register a helper function to be generated at global scope."""
498+
self.helper_functions[helper_graph_info.name] = helper_graph_info
499+
500+
def codegen_helper_functions(self) -> list[ast.stmt]:
501+
"""Generate helper function definitions at global scope."""
502+
helper_defs = []
503+
for helper_graph_info in self.helper_functions.values():
504+
# Determine the number of parameters from the graph
505+
input_nodes = helper_graph_info.find_input_nodes()
506+
507+
# Generate argument list with consistent names
508+
args = []
509+
param_names = []
510+
for i in range(len(input_nodes)):
511+
arg_name = f"param_{i}"
512+
args.append(create_arg(arg_name))
513+
param_names.append(arg_name)
514+
515+
# Store parameter names for use in body generation
516+
helper_graph_info._param_names = param_names
517+
518+
# Process the FX graph to generate the correct helper function body
519+
func_body = self._codegen_helper_function_body(helper_graph_info)
520+
521+
# Generate the function structure with @triton.jit decorator
522+
func_def = create(
523+
ast.FunctionDef,
524+
name=helper_graph_info.name,
525+
args=create_arguments(args),
526+
body=func_body,
527+
decorator_list=[expr_from_string("triton.jit")],
528+
type_params=[],
529+
)
530+
531+
helper_defs.append(func_def)
532+
533+
return helper_defs
534+
535+
def _codegen_helper_function_body(
536+
self, helper_graph_info: HelperFunctionGraphInfo
537+
) -> list[ast.stmt]:
538+
"""Generate the body of a helper function by processing its FX graph."""
539+
from .helper_function import HelperCodegen
540+
from .inductor_lowering import GraphInterpreter
541+
542+
# Create a temporary DeviceFunction for generating the helper function
543+
temp_device_function = DeviceFunction(
544+
name=f"temp_{helper_graph_info.name}",
545+
config=self.config,
546+
codegen=self.codegen,
547+
)
548+
549+
# Use the parameter names from the function definition
550+
param_names = helper_graph_info._param_names
551+
552+
# Create parameter AST nodes for the helper function
553+
param_args = []
554+
for param_name in param_names:
555+
param_args.append(expr_from_string(param_name))
556+
557+
# Process the graph using the existing interpreter infrastructure
558+
with temp_device_function:
559+
helper_codegen = HelperCodegen(temp_device_function)
560+
interpreter = GraphInterpreter(helper_graph_info.graph, helper_codegen)
561+
results = interpreter.run(*param_args)
562+
563+
# Get the generated statements from the temporary device function
564+
statements = temp_device_function.body.copy()
565+
566+
# Ensure there's a return statement
567+
if not statements or not isinstance(statements[-1], ast.Return):
568+
if isinstance(results, ast.AST):
569+
statements.append(create(ast.Return, value=results))
570+
else:
571+
# This should not happen in normal operation
572+
raise RuntimeError(
573+
f"Helper function {helper_graph_info.name} did not produce a valid result"
574+
)
575+
576+
return cast("list[ast.stmt]", statements)
577+
491578
def __enter__(self) -> None:
492579
try:
493580
tls.functions.append(self)

helion/_compiler/device_ir.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import ast
44
import builtins
5-
from collections.abc import Callable
65
import contextlib
76
import dataclasses
87
import functools
@@ -807,9 +806,94 @@ def visit_Call(self, node: ast.Call) -> object:
807806
else:
808807
func = self.visit(node.func)
809808

809+
# Special handling for associative_scan
810+
import helion.language as hl
811+
812+
if isinstance(
813+
(func_type_info := node.func._type_info),
814+
CallableType,
815+
) and (
816+
func_type_info.value is hl.associative_scan or func is hl.associative_scan
817+
):
818+
return self._handle_associative_scan(node, args, kwargs)
819+
810820
# pyre-ignore[6]
811821
return _CheckForIndexCalls.retry_call(func, args, kwargs)
812822

823+
def _handle_associative_scan(
824+
self, node: ast.Call, args: list[object], kwargs: dict[str, object]
825+
) -> object:
826+
"""Handle associative_scan calls by tracing the combine function as a subgraph."""
827+
from ..language import _tracing_ops
828+
829+
combine_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cast(
830+
"Callable[[torch.Tensor, torch.Tensor], torch.Tensor]", args[0]
831+
) # The combine function
832+
input_tensor = args[1] # The input tensor
833+
834+
# Extract other arguments from kwargs
835+
dim = kwargs.get("dim", 0)
836+
reverse = kwargs.get("reverse", False)
837+
838+
# Create a subgraph for the combine function
839+
def run_combine_subgraph(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
840+
# This will trace the combine function
841+
# Check if combine_fn is a Kernel object and extract the underlying function
842+
from .helper_function import extract_helper_function
843+
844+
actual_fn = extract_helper_function(combine_fn)
845+
return actual_fn(a, b)
846+
847+
# Create fake inputs for the combine function based on input tensor element type
848+
if isinstance(input_tensor, torch.Tensor):
849+
fake_a = torch.empty(
850+
[1], dtype=input_tensor.dtype, device=input_tensor.device
851+
)
852+
fake_b = torch.empty(
853+
[1], dtype=input_tensor.dtype, device=input_tensor.device
854+
)
855+
else:
856+
# Fallback for when input_tensor is a proxy
857+
fake_a = torch.empty([1], dtype=torch.float32)
858+
fake_b = torch.empty([1], dtype=torch.float32)
859+
860+
with self.disable_tracing() as tracer:
861+
combine_graph = proxy_tensor.make_fx(
862+
run_combine_subgraph, decomposition_table=select_decomp_table()
863+
)(fake_a, fake_b).graph
864+
865+
combine_graph_id = self.device_ir.add_graph(
866+
combine_graph,
867+
HelperFunctionGraphInfo,
868+
node_args=[], # The combine function doesn't use external args
869+
)
870+
871+
# Create the associative_scan tracing operation
872+
scan_args = (
873+
combine_graph_id,
874+
input_tensor,
875+
dim,
876+
reverse,
877+
)
878+
879+
proxy_args, proxy_kwargs = args_to_proxies(tracer, scan_args)
880+
proxy_out = tracer.create_proxy(
881+
"call_function",
882+
_tracing_ops._associative_scan,
883+
proxy_args,
884+
proxy_kwargs,
885+
)
886+
887+
# The output has the same shape as the input
888+
proxy_tensor.track_tensor_tree(
889+
input_tensor,
890+
proxy_out,
891+
constant=None,
892+
tracer=tracer,
893+
)
894+
895+
return proxy_out
896+
813897
def visit_Attribute(self, node: ast.Attribute) -> object:
814898
return getattr(self.visit(node.value), node.attr)
815899

@@ -898,6 +982,37 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
898982
return device_ir
899983

900984

985+
@dataclasses.dataclass
986+
class HelperFunctionGraphInfo(NodeArgsGraphInfo):
987+
"""Graph info for helper functions in higher-order operations like associative_scan."""
988+
989+
_param_names: list[str] = dataclasses.field(default_factory=list)
990+
991+
@property
992+
def name(self) -> str:
993+
return f"helper_function_{self.graph_id}"
994+
995+
def find_input_nodes(self) -> list[torch.fx.Node]:
996+
"""Find all placeholder nodes (inputs) in the graph."""
997+
return self.graph.find_nodes(op="placeholder")
998+
999+
def codegen(self, state: CodegenState) -> list[object]:
1000+
# For helper functions, we need to inline the function body
1001+
# The helper function takes variable arguments and returns their combination
1002+
1003+
# Generate temporary variable names for the helper function arguments
1004+
# Use the graph's input nodes to determine the number of parameters
1005+
input_nodes = self.find_input_nodes()
1006+
args: list[ast.AST] = []
1007+
1008+
for i in range(len(input_nodes)):
1009+
var_name = state.codegen.tmpvar(prefix=f"helper_arg_{i}")
1010+
args.append(create(ast.Name, id=var_name, ctx=ast.Load()))
1011+
1012+
# Generate the helper function call
1013+
return codegen_call_with_graph(state.codegen, self.graph, args)
1014+
1015+
9011016
def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
9021017
"""
9031018
Remove unnecessary tile_index nodes from the graph.

helion/_compiler/generate_ast.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .ast_extension import statement_from_string
1818
from .compile_environment import CompileEnvironment
1919
from .device_function import DeviceFunction
20+
from .helper_function import CodegenInterface
2021
from .inductor_lowering import CodegenState
2122
from .inductor_lowering import codegen_call_with_graph
2223
from .program_id import ForEachProgramID
@@ -32,19 +33,25 @@
3233
from .type_propagation import TensorType
3334

3435

35-
class GenerateAST(NodeVisitor):
36+
class GenerateAST(NodeVisitor, CodegenInterface):
3637
def __init__(self, func: HostFunction, config: Config) -> None:
37-
super().__init__()
38+
# Initialize NodeVisitor first
39+
NodeVisitor.__init__(self)
40+
41+
# Initialize our attributes
3842
self.host_function = func
3943
self.host_statements: list[ast.AST] = []
4044
self.statements_stack: list[list[ast.AST]] = [self.host_statements]
4145
self.on_device = False
42-
self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
4346
self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
4447
collections.defaultdict(list)
4548
)
4649
self.next_else_block: list[ast.AST] | None = None
4750

51+
# Now create device function and initialize CodegenInterface
52+
self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
53+
CodegenInterface.__init__(self, self.device_function)
54+
4855
def offset_var(self, block_idx: int) -> str:
4956
return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx)
5057

@@ -63,9 +70,6 @@ def add_statement(self, stmt: ast.AST | str | None) -> None:
6370
stmt = statement_from_string(stmt)
6471
self.statements_stack[-1].append(stmt)
6572

66-
def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
67-
return self.device_function.unique_name(prefix, dce=dce)
68-
6973
def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
7074
if isinstance(expr, ast.Name):
7175
return expr
@@ -413,6 +417,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
413417
result = ast.Module(
414418
[
415419
*func.codegen_imports(),
420+
*codegen.device_function.codegen_helper_functions(),
416421
*kernel_def,
417422
host_def,
418423
precompile_def,

helion/_compiler/helper_function.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC
4+
from abc import abstractmethod
5+
import ast
6+
from typing import TYPE_CHECKING
7+
8+
from .ast_extension import create
9+
from .ast_extension import statement_from_string
10+
11+
if TYPE_CHECKING:
12+
import types
13+
14+
from .device_function import DeviceFunction
15+
16+
17+
class CodegenInterface(ABC):
18+
"""Abstract base class for codegen interfaces used by GraphInterpreter."""
19+
20+
def __init__(self, device_function: DeviceFunction) -> None:
21+
self.device_function = device_function
22+
23+
@abstractmethod
24+
def add_statement(self, stmt: ast.AST | str | None) -> None:
25+
"""Add a statement to the generated code."""
26+
27+
def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
28+
"""Generate a temporary variable name."""
29+
return self.device_function.unique_name(prefix, dce=dce)
30+
31+
def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
32+
"""Lift an expression to a temporary variable if needed."""
33+
if isinstance(expr, ast.Name):
34+
return expr
35+
varname = self.tmpvar(dce=dce, prefix=prefix)
36+
self.add_statement(statement_from_string(f"{varname} = expr", expr=expr))
37+
return create(ast.Name, id=varname, ctx=ast.Load())
38+
39+
40+
def extract_helper_function(helper_fn: object) -> types.FunctionType:
41+
"""Extract the actual function from a Kernel object or return as-is.
42+
43+
This utility function centralizes the logic for handling both regular functions
44+
and Kernel objects that wrap functions.
45+
"""
46+
from ..runtime.kernel import Kernel
47+
48+
# pyre-ignore[16]: We check isinstance before accessing .fn
49+
return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn
50+
51+
52+
class HelperCodegen(CodegenInterface):
53+
"""Codegen wrapper for helper function generation."""
54+
55+
def __init__(self, device_function: DeviceFunction) -> None:
56+
super().__init__(device_function)
57+
58+
def add_statement(self, stmt: ast.AST | str | None) -> None:
59+
if stmt is not None:
60+
if isinstance(stmt, str):
61+
stmt = statement_from_string(stmt)
62+
self.device_function.body.append(stmt)

0 commit comments

Comments
 (0)