
Add hl.associative_scan #239


Merged: 1 commit, Jul 8, 2025
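
For context, this PR threads helper-function support through the compiler so a user-written combine function can be passed to the new op. Below is a minimal sketch of the kind of kernel this enables. The kernel body is hypothetical: it assumes `hl.associative_scan` takes a binary combine function, an input, and a `dim`, mirroring Triton's `tl.associative_scan`; the exact signature should be checked against the helion docs.

```python
import torch
import helion
import helion.language as hl


def add_combine(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Any associative binary op works; addition gives a cumulative sum.
    return x + y


@helion.kernel()
def cumsum(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(x.size(0)):
        # Scan each row; the combine fn is compiled into a module-scope
        # helper by the machinery added in this PR.
        out[tile, :] = hl.associative_scan(add_combine, x[tile, :], dim=1)
    return out
```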
9 changes: 7 additions & 2 deletions helion/_compiler/compile_environment.py

@@ -222,10 +222,15 @@ def to_fake(self, obj: object, origin: Origin) -> object:
             ),
         ):
             return obj
-        if isinstance(obj, types.FunctionType):
+        # Handle functions and Kernel objects
+        from ..runtime.kernel import Kernel
+
+        if isinstance(obj, (types.FunctionType, Kernel)):
+            from .helper_function import extract_helper_function
             from .lift_closures import lift_closures

-            return lift_closures(obj, origin)
+            fn = extract_helper_function(obj)
+            return lift_closures(fn, origin)
         if isinstance(obj, ConstExpr):
            return obj.value
        if isinstance(obj, list):
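
`extract_helper_function` lives in the new `helper_function` module, which is not shown in this diff. A plausible minimal shape, purely as a sketch: the `Kernel.fn` attribute is an assumption, not a confirmed API.

```python
import types


def extract_helper_function(obj: object) -> types.FunctionType:
    # Sketch: a combine function may arrive wrapped in a Kernel
    # (e.g. a @helion.kernel-decorated callable); unwrap it so that
    # lift_closures always receives a plain Python function.
    from helion.runtime.kernel import Kernel

    if isinstance(obj, Kernel):
        return obj.fn  # assumption: the Kernel keeps its wrapped fn as .fn
    assert isinstance(obj, types.FunctionType)
    return obj
```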
15 changes: 15 additions & 0 deletions helion/_compiler/device_function.py

@@ -37,6 +37,7 @@

 if TYPE_CHECKING:
     from ..runtime.config import Config
+    from .device_ir import HelperFunctionGraphInfo
     from .generate_ast import GenerateAST
     from .program_id import ProgramIDs

@@ -185,6 +186,10 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
         self.block_size_var_cache: dict[tuple[int, ...], str] = {}
         self.expr_to_var_info: dict[sympy.Expr, VarInfo] = {}

+        from .helper_function import HelperFunctionManager
+
+        self.helper_manager = HelperFunctionManager()
+
         from .indexing_strategy import IndexingStrategy
         from .tile_dispatch import TileStrategyDispatch

@@ -488,6 +493,16 @@ def dead_code_elimination(self) -> None:
                 if v.name in args_to_remove:
                     del cache[k]

+    def register_helper_function(
+        self, helper_graph_info: HelperFunctionGraphInfo
+    ) -> None:
+        """Register a helper function to be generated at global scope."""
+        self.helper_manager.register_helper_function(helper_graph_info)
+
+    def codegen_helper_functions(self) -> list[ast.stmt]:
+        """Generate helper function definitions at global scope."""
+        return self.helper_manager.codegen_helper_functions()
+
     def __enter__(self) -> None:
         try:
             tls.functions.append(self)
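
`HelperFunctionManager` is also defined in the new `helper_function` module, outside this diff. A sketch of the shape these two delegating methods imply; deduplication by name and the `to_ast` rendering hook are assumptions for illustration.

```python
import ast


class HelperFunctionManager:
    def __init__(self) -> None:
        # Keyed by helper name so registering the same graph twice
        # emits only one module-scope definition.
        self._helpers: dict[str, "HelperFunctionGraphInfo"] = {}

    def register_helper_function(
        self, helper_graph_info: "HelperFunctionGraphInfo"
    ) -> None:
        self._helpers.setdefault(helper_graph_info.name, helper_graph_info)

    def codegen_helper_functions(self) -> list[ast.stmt]:
        # Emit one `def` per registered helper, in registration order
        # (dicts preserve insertion order in Python 3.7+).
        defs: list[ast.stmt] = []
        for info in self._helpers.values():
            defs.extend(info.to_ast())  # hypothetical rendering hook
        return defs
```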
27 changes: 25 additions & 2 deletions helion/_compiler/device_ir.py

@@ -2,7 +2,6 @@

 import ast
 import builtins
-from collections.abc import Callable
 import contextlib
 import dataclasses
 import functools

@@ -339,7 +338,11 @@ def build_rolled_reductions(self) -> None:
             for graph_id, graph_info in enumerate([*self.graphs]):
                 assert graph_id == graph_info.graph_id
                 roller = ReductionRoller(self, rdim, graph_to_info)
-                new_graph = roller.process(graph_info.graph)
+                try:
+                    new_graph = roller.process(graph_info.graph)
+                except NotImplementedError:
+                    first = False
+                    break
                 new_graph_id = self.add_graph(
                     new_graph, type(graph_info), **graph_info.kwargs()
                 )

@@ -898,6 +901,26 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
     return device_ir


+@dataclasses.dataclass
+class HelperFunctionGraphInfo(NodeArgsGraphInfo):
+    """Graph info for helper functions in higher-order operations like associative_scan."""
+
+    _param_names: list[str] = dataclasses.field(default_factory=list)
+
+    @property
+    def name(self) -> str:
+        return f"helper_function_{self.graph_id}"
+
+    def find_input_nodes(self) -> list[torch.fx.Node]:
+        """Find all placeholder nodes (inputs) in the graph."""
+        return self.graph.find_nodes(op="placeholder")
+
+    def codegen(self, state: CodegenState) -> list[object]:
+        from .helper_function import codegen_helper_function_graph_info
+
+        return codegen_helper_function_graph_info(self, state)
+
+
 def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
     """
     Remove unnecessary tile_index nodes from the graph.
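
`find_input_nodes` leans on standard torch.fx behavior: every graph input is a `placeholder` node, and `Graph.find_nodes(op=...)` (available in recent PyTorch) retrieves them. For illustration, tracing a two-operand combine function yields exactly the placeholders a helper's parameter list would be built from:

```python
import torch


def combine(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a + b


gm = torch.fx.symbolic_trace(combine)
placeholders = gm.graph.find_nodes(op="placeholder")
print([n.name for n in placeholders])  # ['a', 'b'], one per scan operand
```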
17 changes: 11 additions & 6 deletions helion/_compiler/generate_ast.py

@@ -17,6 +17,7 @@
 from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
+from .helper_function import CodegenInterface
 from .inductor_lowering import CodegenState
 from .inductor_lowering import codegen_call_with_graph
 from .program_id import ForEachProgramID

@@ -32,19 +33,25 @@
 from .type_propagation import TensorType


-class GenerateAST(NodeVisitor):
+class GenerateAST(NodeVisitor, CodegenInterface):
     def __init__(self, func: HostFunction, config: Config) -> None:
-        super().__init__()
+        # Initialize NodeVisitor first
+        NodeVisitor.__init__(self)
+
+        # Initialize our attributes
         self.host_function = func
         self.host_statements: list[ast.AST] = []
         self.statements_stack: list[list[ast.AST]] = [self.host_statements]
         self.on_device = False
-        self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
         self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
             collections.defaultdict(list)
         )
         self.next_else_block: list[ast.AST] | None = None

+        # Now create device function and initialize CodegenInterface
+        self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self)
+        CodegenInterface.__init__(self, self.device_function)
+
     def offset_var(self, block_idx: int) -> str:
         return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx)

@@ -63,9 +70,6 @@ def add_statement(self, stmt: ast.AST | str | None) -> None:
             stmt = statement_from_string(stmt)
         self.statements_stack[-1].append(stmt)

-    def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str:
-        return self.device_function.unique_name(prefix, dce=dce)
-
     def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name:
         if isinstance(expr, ast.Name):
             return expr

@@ -413,6 +417,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
     result = ast.Module(
         [
             *func.codegen_imports(),
+            *codegen.device_function.codegen_helper_functions(),
             *kernel_def,
             host_def,
             precompile_def,
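
The constructor change above is a two-phase initialization: `DeviceFunction`'s constructor receives the `GenerateAST` instance itself, so `CodegenInterface` can only be initialized after the device function exists. Calling each base's `__init__` explicitly, rather than relying on `super()`, makes that ordering visible. A generic illustration of the pattern, with all names hypothetical:

```python
class Visitor:
    def __init__(self) -> None:
        self.statements: list[object] = []


class Interface:
    def __init__(self, target: object) -> None:
        # Needs an object that only exists partway through construction.
        self.target = target


def make_target(owner: object) -> object:
    # Stand-in for DeviceFunction(name, config, codegen=owner).
    return ("device_function_for", owner)


class Generator(Visitor, Interface):
    def __init__(self) -> None:
        Visitor.__init__(self)            # phase 1: plain visitor state
        target = make_target(self)        # built from the half-initialized self
        Interface.__init__(self, target)  # phase 2: wire the dependency in
```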