
Commit 16e2d6f: Add hl.associative_scan

stack-info: PR: #239, branch: jansel/stack/78
1 parent: a6d5031

12 files changed: +1248 -15 lines

helion/_compiler/compile_environment.py

Lines changed: 7 additions & 2 deletions
@@ -222,10 +222,15 @@ def to_fake(self, obj: object, origin: Origin) -> object:
             ),
         ):
             return obj
-        if isinstance(obj, types.FunctionType):
+        # Handle functions and Kernel objects
+        from ..runtime.kernel import Kernel
+
+        if isinstance(obj, (types.FunctionType, Kernel)):
+            from .helper_function import extract_helper_function
             from .lift_closures import lift_closures
 
-            return lift_closures(obj, origin)
+            fn = extract_helper_function(obj)
+            return lift_closures(fn, origin)
         if isinstance(obj, ConstExpr):
             return obj.value
         if isinstance(obj, list):
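With this change, `to_fake` accepts a combine function either as a plain Python function or wrapped in a `Kernel` object. `helion/_compiler/helper_function.py` is one of the files added by this commit whose diff is not shown on this page; a minimal sketch of what `extract_helper_function` presumably does, assuming `Kernel` stores the function it wraps as `.fn`:

```python
# Hypothetical sketch: helper_function.py is added by this commit but its
# diff is not shown on this page. Assumes Kernel exposes the wrapped
# function as `.fn`; the real implementation may differ.
from typing import Any, Callable


def extract_helper_function(fn: Any) -> Callable[..., Any]:
    """Unwrap a Kernel object to the plain function it wraps."""
    from helion.runtime.kernel import Kernel  # import path matches the diff above

    return fn.fn if isinstance(fn, Kernel) else fn
```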

helion/_compiler/device_function.py

Lines changed: 110 additions & 0 deletions
@@ -37,6 +37,7 @@
 
 if TYPE_CHECKING:
     from ..runtime.config import Config
+    from .device_ir import HelperFunctionGraphInfo
     from .generate_ast import GenerateAST
     from .program_id import ProgramIDs
 
@@ -185,6 +186,8 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.block_size_var_cache: dict[tuple[int, ...], str] = {}
         self.expr_to_var_info: dict[sympy.Expr, VarInfo] = {}
 
+        self.helper_functions: dict[str, HelperFunctionGraphInfo] = {}
+
         from .indexing_strategy import IndexingStrategy
         from .tile_dispatch import TileStrategyDispatch
 
@@ -488,6 +491,113 @@ def dead_code_elimination(self) -> None:
                 if v.name in args_to_remove:
                     del cache[k]
 
+    def register_helper_function(
+        self, helper_graph_info: HelperFunctionGraphInfo
+    ) -> None:
+        """Register a helper function to be generated at global scope."""
+        self.helper_functions[helper_graph_info.name] = helper_graph_info
+
+    def codegen_helper_functions(self) -> list[ast.stmt]:
+        """Generate helper function definitions at global scope."""
+        helper_defs = []
+        for helper_graph_info in self.helper_functions.values():
+            # Determine the number of parameters from the graph
+            input_nodes = helper_graph_info.find_input_nodes()
+
+            # Generate argument list with consistent names
+            args = []
+            param_names = []
+            for i in range(len(input_nodes)):
+                arg_name = f"param_{i}"
+                args.append(create_arg(arg_name))
+                param_names.append(arg_name)
+
+            # Store parameter names for use in body generation
+            helper_graph_info._param_names = param_names
+
+            # Process the FX graph to generate the correct helper function body
+            func_body = self._codegen_helper_function_body(helper_graph_info)
+
+            # Generate the function structure with @triton.jit decorator
+            func_def = create(
+                ast.FunctionDef,
+                name=helper_graph_info.name,
+                args=create_arguments(args),
+                body=func_body,
+                decorator_list=[expr_from_string("triton.jit")],
+                type_params=[],
+            )
+
+            helper_defs.append(func_def)
+
+        return helper_defs
+
+    def _codegen_helper_function_body(
+        self, helper_graph_info: HelperFunctionGraphInfo
+    ) -> list[ast.stmt]:
+        """Generate the body of a helper function by processing its FX graph."""
+        temp_device_function = self._create_temp_device_function(helper_graph_info)
+        param_args = self._create_parameter_args(helper_graph_info)
+
+        with temp_device_function:
+            results = self._process_helper_graph(
+                helper_graph_info, temp_device_function, param_args
+            )
+            statements = temp_device_function.body.copy()
+            self._ensure_return_statement(statements, results, helper_graph_info.name)
+
+        return cast("list[ast.stmt]", statements)
+
+    def _create_temp_device_function(
+        self, helper_graph_info: HelperFunctionGraphInfo
+    ) -> DeviceFunction:
+        """Create a temporary DeviceFunction for helper function generation."""
+        return DeviceFunction(
+            name=f"temp_{helper_graph_info.name}",
+            config=self.config,
+            codegen=self.codegen,
+        )
+
+    def _create_parameter_args(
+        self, helper_graph_info: HelperFunctionGraphInfo
+    ) -> list[ast.AST]:
+        """Create parameter AST nodes for the helper function."""
+        param_names = helper_graph_info._param_names
+        return [expr_from_string(param_name) for param_name in param_names]
+
+    def _process_helper_graph(
+        self,
+        helper_graph_info: HelperFunctionGraphInfo,
+        temp_device_function: DeviceFunction,
+        param_args: list[ast.AST],
+    ) -> object:
+        """Process the graph using the existing interpreter infrastructure."""
+        from .helper_function import HelperCodegen
+        from .inductor_lowering import GraphInterpreter
+
+        helper_codegen = HelperCodegen(temp_device_function)
+        interpreter = GraphInterpreter(helper_graph_info.graph, helper_codegen)
+        return interpreter.run(*param_args)
+
+    def _ensure_return_statement(
+        self, statements: list[ast.AST], results: object, function_name: str
+    ) -> None:
+        """Ensure the function body has a proper return statement."""
+        if statements and isinstance(statements[-1], ast.Return):
+            return
+
+        if isinstance(results, ast.AST):
+            statements.append(create(ast.Return, value=results))
+        elif isinstance(results, (list, tuple)) and all(
+            isinstance(r, ast.AST) for r in results
+        ):
+            tuple_ast = create(ast.Tuple, elts=list(results), ctx=ast.Load())
+            statements.append(create(ast.Return, value=tuple_ast))
+        else:
+            raise RuntimeError(
+                f"Helper function {function_name} produced invalid result: {type(results)} {results}"
+            )
+
     def __enter__(self) -> None:
         try:
            tls.functions.append(self)
helion/_compiler/device_ir.py

Lines changed: 187 additions & 2 deletions
@@ -2,7 +2,6 @@
 
 import ast
 import builtins
-from collections.abc import Callable
 import contextlib
 import dataclasses
 import functools
@@ -339,7 +338,11 @@ def build_rolled_reductions(self) -> None:
             for graph_id, graph_info in enumerate([*self.graphs]):
                 assert graph_id == graph_info.graph_id
                 roller = ReductionRoller(self, rdim, graph_to_info)
-                new_graph = roller.process(graph_info.graph)
+                try:
+                    new_graph = roller.process(graph_info.graph)
+                except NotImplementedError:
+                    first = False
+                    break
                 new_graph_id = self.add_graph(
                     new_graph, type(graph_info), **graph_info.kwargs()
                 )
@@ -807,9 +810,160 @@ def visit_Call(self, node: ast.Call) -> object:
         else:
             func = self.visit(node.func)
 
+        # Special handling for associative_scan
+        if isinstance(
+            (func_type_info := node.func._type_info),
+            CallableType,
+        ) and (
+            func_type_info.value is hl.associative_scan or func is hl.associative_scan
+        ):
+            return self._handle_associative_scan(node, args, kwargs)
+
         # pyre-ignore[6]
         return _CheckForIndexCalls.retry_call(func, args, kwargs)
 
+    def _handle_associative_scan(
+        self, node: ast.Call, args: list[object], kwargs: dict[str, object]
+    ) -> object:
+        """Handle associative_scan calls by tracing the combine function as a subgraph."""
+        from ..language import _tracing_ops
+
+        combine_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cast(
+            "Callable[[torch.Tensor, torch.Tensor], torch.Tensor]", args[0]
+        )  # The combine function
+        input_tensor = args[1]  # The input tensor
+
+        # Extract other arguments from kwargs
+        dim = kwargs.get("dim", 0)
+        reverse = kwargs.get("reverse", False)
+
+        # Detect if we're dealing with tuple inputs
+        is_tuple_input = isinstance(input_tensor, (tuple, list))
+
+        # Create a subgraph for the combine function
+        if is_tuple_input:
+
+            def run_combine_subgraph(
+                *args: torch.Tensor,
+            ) -> tuple[torch.Tensor, ...]:
+                # This will trace the combine function with unpacked tuple inputs
+                from .helper_function import extract_helper_function
+
+                # For tuple inputs, the combine function expects unpacked arguments
+                # args = [left_val1, left_val2, ..., right_val1, right_val2, ...]
+                # We need to call: combine_fn(left_val1, left_val2, ..., right_val1, right_val2, ...)
+                actual_fn = extract_helper_function(combine_fn)
+                result = actual_fn(*args)
+                return result if isinstance(result, tuple) else (result,)
+        else:
+
+            def run_combine_subgraph(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+                # This will trace the combine function with tensor inputs
+                from .helper_function import extract_helper_function
+
+                actual_fn = extract_helper_function(combine_fn)
+                return actual_fn(a, b)
+
+        # Create fake inputs for the combine function
+        if is_tuple_input:
+            # Handle tuple inputs
+            if isinstance(input_tensor, (tuple, list)) and all(
+                isinstance(t, torch.Tensor) for t in input_tensor
+            ):
+                fake_inputs = []
+                for tensor in input_tensor:
+                    fake_inputs.extend(
+                        [
+                            torch.empty([1], dtype=tensor.dtype, device=tensor.device),
+                            torch.empty([1], dtype=tensor.dtype, device=tensor.device),
+                        ]
+                    )
+                fake_a_and_b = fake_inputs
+            else:
+                # Fallback for when input_tensor is a proxy tuple
+                # Assume 2 tensors of float32 for now
+                fake_a_and_b = [torch.empty([1], dtype=torch.float32) for _ in range(4)]
+        else:
+            # Handle single tensor inputs
+            if isinstance(input_tensor, torch.Tensor):
+                fake_a = torch.empty(
+                    [1], dtype=input_tensor.dtype, device=input_tensor.device
+                )
+                fake_b = torch.empty(
+                    [1], dtype=input_tensor.dtype, device=input_tensor.device
+                )
+                fake_a_and_b = [fake_a, fake_b]
+            else:
+                # Fallback for when input_tensor is a proxy
+                fake_a = torch.empty([1], dtype=torch.float32)
+                fake_b = torch.empty([1], dtype=torch.float32)
+                fake_a_and_b = [fake_a, fake_b]
+
+        with self.disable_tracing() as tracer:
+            combine_graph = proxy_tensor.make_fx(
+                run_combine_subgraph, decomposition_table=select_decomp_table()
+            )(*fake_a_and_b).graph
+
+        combine_graph_id = self.device_ir.add_graph(
+            combine_graph,
+            HelperFunctionGraphInfo,
+            node_args=[],  # The combine function doesn't use external args
+        )
+
+        # Create the associative_scan tracing operation
+        scan_args = (
+            combine_graph_id,
+            input_tensor,
+            dim,
+            reverse,
+            is_tuple_input,
+        )
+
+        proxy_args, proxy_kwargs = args_to_proxies(tracer, scan_args)
+        proxy_out = tracer.create_proxy(
+            "call_function",
+            _tracing_ops._associative_scan,
+            proxy_args,
+            proxy_kwargs,
+        )
+
+        # The output has the same shape as the input
+        if is_tuple_input:
+            # For tuple inputs, track each element separately and return a tuple
+            proxy_tensor.track_tensor_tree(
+                input_tensor,
+                proxy_out,
+                constant=None,
+                tracer=tracer,
+            )
+            # Convert the proxy output to a tuple of individual proxies
+            tuple_proxies = []
+            assert isinstance(
+                input_tensor, (tuple, list)
+            )  # Guaranteed when is_tuple_input is True
+            for i, tensor in enumerate(input_tensor):
+                element_proxy = tracer.create_proxy(
+                    "call_function",
+                    operator.getitem,
+                    (proxy_out, i),
+                    {},
+                )
+                proxy_tensor.track_tensor_tree(
+                    tensor,
+                    element_proxy,
+                    constant=None,
+                    tracer=tracer,
+                )
+                tuple_proxies.append(tensor)
+            return tuple(tuple_proxies)
+        proxy_tensor.track_tensor_tree(
+            input_tensor,
+            proxy_out,
+            constant=None,
+            tracer=tracer,
+        )
+        return proxy_out
+
     def visit_Attribute(self, node: ast.Attribute) -> object:
         return getattr(self.visit(node.value), node.attr)
 
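`_handle_associative_scan` traces the combine function into its own FX graph by calling it under `make_fx` with `[1]`-shaped fake tensors, then records a single `_tracing_ops._associative_scan` node that references the subgraph by id. A standalone illustration of that tracing step (plain PyTorch, independent of Helion):

```python
# Standalone illustration of the subgraph tracing step above (not Helion code).
import torch
from torch.fx.experimental import proxy_tensor


def add_combine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b


# [1]-shaped fake inputs stand in for the scan's element type, mirroring
# the fake_a/fake_b construction in _handle_associative_scan.
fake_a = torch.empty([1], dtype=torch.float32)
fake_b = torch.empty([1], dtype=torch.float32)
gm = proxy_tensor.make_fx(add_combine)(fake_a, fake_b)
print(gm.graph)  # placeholders a, b -> aten.add -> output
```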
@@ -898,6 +1052,37 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     return device_ir
 
 
+@dataclasses.dataclass
+class HelperFunctionGraphInfo(NodeArgsGraphInfo):
+    """Graph info for helper functions in higher-order operations like associative_scan."""
+
+    _param_names: list[str] = dataclasses.field(default_factory=list)
+
+    @property
+    def name(self) -> str:
+        return f"helper_function_{self.graph_id}"
+
+    def find_input_nodes(self) -> list[torch.fx.Node]:
+        """Find all placeholder nodes (inputs) in the graph."""
+        return self.graph.find_nodes(op="placeholder")
+
+    def codegen(self, state: CodegenState) -> list[object]:
+        # For helper functions, we need to inline the function body
+        # The helper function takes variable arguments and returns their combination
+
+        # Generate temporary variable names for the helper function arguments
+        # Use the graph's input nodes to determine the number of parameters
+        input_nodes = self.find_input_nodes()
+        args: list[ast.AST] = []
+
+        for i in range(len(input_nodes)):
+            var_name = state.codegen.tmpvar(prefix=f"helper_arg_{i}")
+            args.append(create(ast.Name, id=var_name, ctx=ast.Load()))
+
+        # Generate the helper function call
+        return codegen_call_with_graph(state.codegen, self.graph, args)
+
+
 def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
     """
     Remove unnecessary tile_index nodes from the graph.

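For context, a hypothetical end-to-end usage sketch of the new API. The argument order (`combine_fn`, `input`, `dim`, `reverse`) follows `_handle_associative_scan` above; the kernel scaffolding (`helion.kernel`, `hl.tile`) follows Helion's usual conventions, and the exact signature may differ from the examples and tests elsewhere in this commit:

```python
# Hypothetical usage sketch; hl.associative_scan's argument order is taken
# from _handle_associative_scan above, the rest is assumed Helion boilerplate.
import torch

import helion
import helion.language as hl


def add_combine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b


@helion.kernel()
def cumsum(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # Inclusive scan of each row along the last dimension.
        out[tile, :] = hl.associative_scan(add_combine, x[tile, :], dim=1)
    return out
```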