diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 1376b948..11639ac1 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -222,10 +222,15 @@ def to_fake(self, obj: object, origin: Origin) -> object: ), ): return obj - if isinstance(obj, types.FunctionType): + # Handle functions and Kernel objects + from ..runtime.kernel import Kernel + + if isinstance(obj, (types.FunctionType, Kernel)): + from .helper_function import extract_helper_function from .lift_closures import lift_closures - return lift_closures(obj, origin) + fn = extract_helper_function(obj) + return lift_closures(fn, origin) if isinstance(obj, ConstExpr): return obj.value if isinstance(obj, list): diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index 3af3fd90..41a70bfe 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -37,6 +37,7 @@ if TYPE_CHECKING: from ..runtime.config import Config + from .device_ir import HelperFunctionGraphInfo from .generate_ast import GenerateAST from .program_id import ProgramIDs @@ -185,6 +186,10 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None: self.block_size_var_cache: dict[tuple[int, ...], str] = {} self.expr_to_var_info: dict[sympy.Expr, VarInfo] = {} + from .helper_function import HelperFunctionManager + + self.helper_manager = HelperFunctionManager() + from .indexing_strategy import IndexingStrategy from .tile_dispatch import TileStrategyDispatch @@ -488,6 +493,16 @@ def dead_code_elimination(self) -> None: if v.name in args_to_remove: del cache[k] + def register_helper_function( + self, helper_graph_info: HelperFunctionGraphInfo + ) -> None: + """Register a helper function to be generated at global scope.""" + self.helper_manager.register_helper_function(helper_graph_info) + + def codegen_helper_functions(self) -> list[ast.stmt]: + """Generate helper function definitions at global scope.""" + return self.helper_manager.codegen_helper_functions() + def __enter__(self) -> None: try: tls.functions.append(self) diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 59f38086..9ab813be 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -2,7 +2,6 @@ import ast import builtins -from collections.abc import Callable import contextlib import dataclasses import functools @@ -339,7 +338,11 @@ def build_rolled_reductions(self) -> None: for graph_id, graph_info in enumerate([*self.graphs]): assert graph_id == graph_info.graph_id roller = ReductionRoller(self, rdim, graph_to_info) - new_graph = roller.process(graph_info.graph) + try: + new_graph = roller.process(graph_info.graph) + except NotImplementedError: + first = False + break new_graph_id = self.add_graph( new_graph, type(graph_info), **graph_info.kwargs() ) @@ -898,6 +901,26 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR: return device_ir +@dataclasses.dataclass +class HelperFunctionGraphInfo(NodeArgsGraphInfo): + """Graph info for helper functions in higher-order operations like associative_scan.""" + + _param_names: list[str] = dataclasses.field(default_factory=list) + + @property + def name(self) -> str: + return f"helper_function_{self.graph_id}" + + def find_input_nodes(self) -> list[torch.fx.Node]: + """Find all placeholder nodes (inputs) in the graph.""" + return self.graph.find_nodes(op="placeholder") + + def codegen(self, state: CodegenState) -> list[object]: + from 
.helper_function import codegen_helper_function_graph_info + + return codegen_helper_function_graph_info(self, state) + + def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None: """ Remove unnecessary tile_index nodes from the graph. diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index 85e0338f..7b9ecf0f 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -17,6 +17,7 @@ from .ast_extension import statement_from_string from .compile_environment import CompileEnvironment from .device_function import DeviceFunction +from .helper_function import CodegenInterface from .inductor_lowering import CodegenState from .inductor_lowering import codegen_call_with_graph from .program_id import ForEachProgramID @@ -32,19 +33,25 @@ from .type_propagation import TensorType -class GenerateAST(NodeVisitor): +class GenerateAST(NodeVisitor, CodegenInterface): def __init__(self, func: HostFunction, config: Config) -> None: - super().__init__() + # Initialize NodeVisitor first + NodeVisitor.__init__(self) + + # Initialize our attributes self.host_function = func self.host_statements: list[ast.AST] = [] self.statements_stack: list[list[ast.AST]] = [self.host_statements] self.on_device = False - self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self) self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = ( collections.defaultdict(list) ) self.next_else_block: list[ast.AST] | None = None + # Now create device function and initialize CodegenInterface + self.device_function = DeviceFunction(f"_{func.name}_kernel", config, self) + CodegenInterface.__init__(self, self.device_function) + def offset_var(self, block_idx: int) -> str: return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx) @@ -63,9 +70,6 @@ def add_statement(self, stmt: ast.AST | str | None) -> None: stmt = statement_from_string(stmt) self.statements_stack[-1].append(stmt) - def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str: - return self.device_function.unique_name(prefix, dce=dce) - def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name: if isinstance(expr, ast.Name): return expr @@ -413,6 +417,7 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST: result = ast.Module( [ *func.codegen_imports(), + *codegen.device_function.codegen_helper_functions(), *kernel_def, host_def, precompile_def, diff --git a/helion/_compiler/helper_function.py b/helion/_compiler/helper_function.py new file mode 100644 index 00000000..22a1accc --- /dev/null +++ b/helion/_compiler/helper_function.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod +import ast +import inspect +from typing import TYPE_CHECKING +from typing import Callable +from typing import Literal +from typing import cast + +import torch + +from .ast_extension import create +from .ast_extension import create_arg +from .ast_extension import create_arguments +from .ast_extension import expr_from_string +from .ast_extension import statement_from_string + +if TYPE_CHECKING: + import types + + from .device_function import DeviceFunction + from .device_ir import HelperFunctionGraphInfo + + +class CodegenInterface(ABC): + """Abstract base class for codegen interfaces used by GraphInterpreter.""" + + def __init__(self, device_function: DeviceFunction) -> None: + self.device_function = device_function + + @abstractmethod + def add_statement(self, stmt: ast.AST | str | None) -> 
None: + """Add a statement to the generated code.""" + + def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str: + """Generate a temporary variable name.""" + return self.device_function.unique_name(prefix, dce=dce) + + def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name: + """Lift an expression to a temporary variable if needed.""" + if isinstance(expr, ast.Name): + return expr + varname = self.tmpvar(dce=dce, prefix=prefix) + self.add_statement(statement_from_string(f"{varname} = expr", expr=expr)) + return create(ast.Name, id=varname, ctx=ast.Load()) + + +def extract_helper_function(helper_fn: object) -> types.FunctionType: + """Extract the actual function from a Kernel object or return as-is. + + This utility function centralizes the logic for handling both regular functions + and Kernel objects that wrap functions. + """ + from ..runtime.kernel import Kernel + + # pyre-ignore[16]: We check isinstance before accessing .fn + return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn + + +CombineFunctionBasic = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] +CombineFunctionTuple = Callable[ + [tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...] +] +CombineFunctionUnpacked = Callable[[torch.Tensor, ...], tuple[torch.Tensor, ...]] +CombineFunction = CombineFunctionBasic | CombineFunctionTuple | CombineFunctionUnpacked + + +def create_combine_function_wrapper( + combine_fn: CombineFunction, + *, + is_tuple_input: bool, + target_format: Literal["tuple", "unpacked"], +) -> CombineFunction: + """ + Create a wrapper around combine_fn that converts between different combine function formats. + + Args: + combine_fn: The original combine function + is_tuple_input: Whether the input is a tuple + target_format: Either 'tuple' or 'unpacked' format + - 'tuple': expects (left_tuple, right_tuple) for tuple inputs + - 'unpacked': expects (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) for tuple inputs + + Returns: + A wrapper function that converts between the formats + """ + # Extract the actual function (handles @helion.kernel decorated functions) + actual_fn = extract_helper_function(combine_fn) + + # For single tensor inputs, no conversion needed + if not is_tuple_input: + return actual_fn + + # Inspect the original function signature to determine its format + sig = inspect.signature(actual_fn) + param_count = len(sig.parameters) + + # Determine the original format based on parameter count + # If it has 2 parameters, it's tuple format: (left_tuple, right_tuple) + # If it has 4+ parameters, it's unpacked format: (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) + original_format = "tuple" if param_count < 4 else "unpacked" + + # If the original format matches target format, no conversion needed + if target_format == original_format: + return actual_fn + + # Create conversion wrapper + if target_format == "tuple" and original_format == "unpacked": + # Convert from unpacked to tuple format + # Target: (left_tuple, right_tuple) + # Original: (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) + def tuple_wrapper( + left_tuple: tuple[torch.Tensor, ...], right_tuple: tuple[torch.Tensor, ...] 
+ ) -> tuple[torch.Tensor, ...]: + # pyre-ignore[6] + return inner_unpacked(*left_tuple, *right_tuple) + + inner_unpacked: CombineFunctionUnpacked = cast( + "CombineFunctionUnpacked", actual_fn + ) + return tuple_wrapper + + if target_format == "unpacked" and original_format == "tuple": + # Convert from tuple to unpacked format + # Target: (left_elem0, left_elem1, ..., right_elem0, right_elem1, ...) + # Original: (left_tuple, right_tuple) + def unpacked_wrapper(*args: torch.Tensor) -> tuple[torch.Tensor, ...]: + num_args = len(args) + assert (num_args % 2) == 0 + half = num_args // 2 + left_tuple = args[:half] + right_tuple = args[half:] + return inner_tuple((*left_tuple,), (*right_tuple,)) + + inner_tuple: CombineFunctionTuple = cast("CombineFunctionTuple", actual_fn) + return unpacked_wrapper + + # Should not reach here + raise ValueError( + f"Unsupported conversion from {original_format} to {target_format}" + ) + + +class HelperCodegen(CodegenInterface): + """Codegen wrapper for helper function generation.""" + + def __init__(self, device_function: DeviceFunction) -> None: + super().__init__(device_function) + + def add_statement(self, stmt: ast.AST | str | None) -> None: + if stmt is not None: + if isinstance(stmt, str): + stmt = statement_from_string(stmt) + self.device_function.body.append(stmt) + + +class HelperFunctionManager: + """Manages helper function registration and code generation.""" + + def __init__(self) -> None: + self.helper_functions: dict[str, HelperFunctionGraphInfo] = {} + + def register_helper_function( + self, helper_graph_info: HelperFunctionGraphInfo + ) -> None: + """Register a helper function to be generated at global scope.""" + self.helper_functions[helper_graph_info.name] = helper_graph_info + + def codegen_helper_functions(self) -> list[ast.stmt]: + """Generate helper function definitions at global scope.""" + helper_defs = [] + for helper_graph_info in self.helper_functions.values(): + # Determine the number of parameters from the graph + input_nodes = helper_graph_info.find_input_nodes() + + # Generate argument list with consistent names + args = [] + param_names = [] + for i in range(len(input_nodes)): + arg_name = f"param_{i}" + args.append(create_arg(arg_name)) + param_names.append(arg_name) + + # Store parameter names for use in body generation + helper_graph_info._param_names = param_names + + # Process the FX graph to generate the correct helper function body + func_body = self._codegen_helper_function_body(helper_graph_info) + + # Generate the function structure with @triton.jit decorator + func_def = create( + ast.FunctionDef, + name=helper_graph_info.name, + args=create_arguments(args), + body=func_body, + decorator_list=[expr_from_string("triton.jit")], + type_params=[], + ) + + helper_defs.append(func_def) + + return helper_defs + + def _codegen_helper_function_body( + self, helper_graph_info: HelperFunctionGraphInfo + ) -> list[ast.stmt]: + """Generate the body of a helper function by processing its FX graph.""" + temp_device_function = self._create_temp_device_function(helper_graph_info) + param_args = self._create_parameter_args(helper_graph_info) + + with temp_device_function: + results = self._process_helper_graph( + helper_graph_info, temp_device_function, param_args + ) + statements = temp_device_function.body.copy() + self._ensure_return_statement(statements, results, helper_graph_info.name) + + return cast("list[ast.stmt]", statements) + + def _create_temp_device_function( + self, helper_graph_info: HelperFunctionGraphInfo + ) -> 
DeviceFunction: + """Create a temporary DeviceFunction for helper function generation.""" + # Import here to avoid circular imports + from .device_function import DeviceFunction + + current = DeviceFunction.current() + + return DeviceFunction( + name=f"temp_{helper_graph_info.name}", + config=current.config, + codegen=current.codegen, + ) + + def _create_parameter_args( + self, helper_graph_info: HelperFunctionGraphInfo + ) -> list[ast.AST]: + """Create parameter AST nodes for the helper function.""" + param_names = helper_graph_info._param_names + return [expr_from_string(param_name) for param_name in param_names] + + def _process_helper_graph( + self, + helper_graph_info: HelperFunctionGraphInfo, + temp_device_function: DeviceFunction, + param_args: list[ast.AST], + ) -> object: + """Process the graph using the existing interpreter infrastructure.""" + from .inductor_lowering import GraphInterpreter + + helper_codegen = HelperCodegen(temp_device_function) + interpreter = GraphInterpreter(helper_graph_info.graph, helper_codegen) + return interpreter.run(*param_args) + + def _ensure_return_statement( + self, statements: list[ast.AST], results: object, function_name: str + ) -> None: + """Ensure the function body has a proper return statement.""" + if statements and isinstance(statements[-1], ast.Return): + return + + if isinstance(results, ast.AST): + statements.append(create(ast.Return, value=results)) + elif isinstance(results, (list, tuple)) and all( + isinstance(r, ast.AST) for r in results + ): + tuple_ast = create(ast.Tuple, elts=list(results), ctx=ast.Load()) + statements.append(create(ast.Return, value=tuple_ast)) + else: + raise RuntimeError( + f"Helper function {function_name} produced invalid result: {type(results)} {results}" + ) + + +def codegen_helper_function_graph_info( + helper_graph_info: HelperFunctionGraphInfo, state: object +) -> list[object]: + """Generate code for HelperFunctionGraphInfo objects.""" + from .inductor_lowering import CodegenState + from .inductor_lowering import codegen_call_with_graph + + if not isinstance(state, CodegenState): + raise TypeError(f"Expected CodegenState, got {type(state)}") + + # For helper functions, we need to inline the function body + # The helper function takes variable arguments and returns their combination + + # Generate temporary variable names for the helper function arguments + # Use the graph's input nodes to determine the number of parameters + input_nodes = helper_graph_info.find_input_nodes() + args: list[ast.AST] = [] + + for i in range(len(input_nodes)): + var_name = state.codegen.tmpvar(prefix=f"helper_arg_{i}") + args.append(create(ast.Name, id=var_name, ctx=ast.Load())) + + # Generate the helper function call + return codegen_call_with_graph(state.codegen, helper_graph_info.graph, args) diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index d968f657..5f44d797 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -62,6 +62,7 @@ from .. 
import Config from .device_function import DeviceFunction from .generate_ast import GenerateAST + from .helper_function import CodegenInterface from .tile_dispatch import TileStrategyDispatch CodegenHandler = Callable[["GraphInterpreter", torch.fx.Node], object] @@ -374,13 +375,14 @@ def install_kernel_handlers( self, ctx: GraphInterpreter, node: torch.fx.Node ) -> ContextManager[None]: return install_inductor_kernel_handlers( - ctx.cg, dict(zip(self.input_names, self.input_asts(ctx, node), strict=True)) + ctx.cg, + dict(zip(self.input_names, self.input_asts(ctx, node), strict=True)), ) @contextlib.contextmanager def install_inductor_kernel_handlers( - cg: GenerateAST, args: dict[str, ast.AST] + cg: CodegenInterface, args: dict[str, ast.AST] ) -> Iterator[None]: with ( inductor_config.patch( @@ -480,6 +482,12 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: self.buffer.data.inner_fn(indices, reduction_indices) ) + from .. import exc + from .generate_ast import GenerateAST + + if not isinstance(ctx.cg, GenerateAST): + raise exc.NotAllowedInHelperFunction + state = CodegenState( ctx.cg, fx_node=node, @@ -544,6 +552,12 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object: proxy_args = [*map_arg(node.args, lambda arg: arg.meta["val"])] assert self.api_func._codegen is not None + from .. import exc + from .generate_ast import GenerateAST + + if not isinstance(ctx.cg, GenerateAST): + raise exc.NotAllowedInHelperFunction + return self.api_func._codegen( CodegenState( ctx.cg, @@ -894,7 +908,9 @@ def codegen_baddbmm(ctx: GraphInterpreter, node: torch.fx.Node) -> ast.AST: class GenerateASTFromInductor(DefaultHandler): - def __init__(self, cg: GenerateAST, input_name_lookup: dict[str, ast.AST]) -> None: + def __init__( + self, cg: CodegenInterface, input_name_lookup: dict[str, ast.AST] + ) -> None: super().__init__() self.parent_handler = TritonOverrides() self.cg = cg @@ -935,7 +951,7 @@ def _unpack_opsvalue(value: object) -> str: class GraphInterpreter(Interpreter): - def __init__(self, graph: torch.fx.Graph, cg: GenerateAST) -> None: + def __init__(self, graph: torch.fx.Graph, cg: CodegenInterface) -> None: super().__init__(_LazyGraphModule({}, graph), garbage_collect_values=False) self.cg = cg diff --git a/helion/exc.py b/helion/exc.py index 51826af9..0749fe74 100644 --- a/helion/exc.py +++ b/helion/exc.py @@ -322,3 +322,7 @@ class UnsupportedPythonType(BaseError): class TypeInferenceError(BaseError): message = "{0}" + + +class NotAllowedInHelperFunction(BaseError): + message = "This operation is not allowed inside helper functions. It requires kernel context." 
diff --git a/helion/language/__init__.py b/helion/language/__init__.py index 2bbcdb1f..d0dceeb9 100644 --- a/helion/language/__init__.py +++ b/helion/language/__init__.py @@ -11,6 +11,9 @@ from .memory_ops import atomic_add as atomic_add from .memory_ops import load as load from .memory_ops import store as store +from .scan_ops import associative_scan as associative_scan +from .scan_ops import cumprod as cumprod +from .scan_ops import cumsum as cumsum from .tile_ops import tile_begin as tile_begin from .tile_ops import tile_block_size as tile_block_size from .tile_ops import tile_end as tile_end diff --git a/helion/language/_decorators.py b/helion/language/_decorators.py index dbf6727a..aa6b2bb9 100644 --- a/helion/language/_decorators.py +++ b/helion/language/_decorators.py @@ -76,6 +76,7 @@ class APIFunc(Protocol): _fake_fn: Callable[..., object] | None _prepare_args: Callable[[tuple[object, ...]], tuple[object, ...]] _get_masked_value: Callable[[torch.fx.Node], float | bool | None] | None + _to_device_ir: Callable[..., object] | None _signature: inspect.Signature def __call__(self, *args: object, **kwargs: object) -> object: ... @@ -150,16 +151,20 @@ def wrapper(*args: object, **kwargs: object) -> object: # We hit type errors if we use the regular custom_op overload, instead we # intercept the call and fake the custom op. with proxy_tensor.disable_proxy_modes_tracing(): - proxy_out = tracer.create_proxy( - "call_function", - wrapper, - *args_to_proxies(tracer, flat_args, {}), - ) - assert api._fake_fn is not None - out = api._fake_fn(*flat_args) - proxy_tensor.track_tensor_tree( - out, proxy_out, constant=None, tracer=tracer - ) + # Use _to_device_ir if available, otherwise use _fake_fn with proxy creation + if api._to_device_ir is not None: + out = api._to_device_ir(tracer, *flat_args) + else: + proxy_out = tracer.create_proxy( + "call_function", + wrapper, + *args_to_proxies(tracer, flat_args, {}), + ) + assert api._fake_fn is not None + out = api._fake_fn(*flat_args) + proxy_tensor.track_tensor_tree( + out, proxy_out, constant=None, tracer=tracer + ) return out api: APIFunc = cast("APIFunc", wrapper) @@ -176,6 +181,7 @@ def wrapper(*args: object, **kwargs: object) -> object: api._codegen = None api._fake_fn = None api._get_masked_value = None + api._to_device_ir = None api._signature = signature or inspect.signature( cast("Callable[..., object]", fn) ) @@ -267,6 +273,20 @@ def _impl( return _impl +def register_to_device_ir( + original_fn: Callable[..., object], +) -> _NoReturnDecorator[object]: + def _impl(to_device_ir_fn: Callable[..., object]) -> Callable[..., Never]: + assert is_api_func(original_fn), ( + f"{register_to_device_ir.__qualname__} can only be used on API functions" + ) + assert original_fn._to_device_ir is None + original_fn._to_device_ir = to_device_ir_fn + return _no_call + + return _impl + + def _default_type_function( fake_fn: Callable[..., object], tiles_as_sizes: bool ) -> Callable[..., TypeInfo]: @@ -292,19 +312,17 @@ def _to_proxy(arg: TypeInfo) -> object: # Tracks 1-1 mapping between Python functions and their Helion API counterparts within device function. 
-_DEVICE_FUNC_REPLACEMENTS: dict[object, APIFunc] = {} +_DEVICE_FUNC_REPLACEMENTS: dict[object, Callable[..., object]] = {} def device_func_replacement(python_func: object) -> _Decorator: def _impl(fn: _C) -> _C: - assert is_api_func(fn), ( - f"{device_func_replacement.__qualname__} can only be used on API functions" - ) + assert callable(fn) _DEVICE_FUNC_REPLACEMENTS[python_func] = fn - return fn # pyre-ignore[7] + return fn return _impl -def get_device_func_replacement(func: object) -> APIFunc | None: +def get_device_func_replacement(func: object) -> Callable[..., object] | None: return _DEVICE_FUNC_REPLACEMENTS.get(func) diff --git a/helion/language/scan_ops.py b/helion/language/scan_ops.py new file mode 100644 index 00000000..b5af30c4 --- /dev/null +++ b/helion/language/scan_ops.py @@ -0,0 +1,349 @@ +from __future__ import annotations + +import ast +import operator +from typing import TYPE_CHECKING +from typing import cast +from typing import overload + +import torch +import torch._higher_order_ops as higher_order_ops +from torch.fx.experimental import proxy_tensor + +from .. import exc +from . import _decorators + +if TYPE_CHECKING: + from .._compiler.helper_function import CombineFunction + from .._compiler.inductor_lowering import CodegenState + from .._compiler.type_propagation import Origin + from .._compiler.type_propagation import TypeInfo + + +__all__ = ["associative_scan", "cumprod", "cumsum"] + + +@overload +@_decorators.device_func_replacement(higher_order_ops.associative_scan) +@_decorators.api(is_device_only=True) +def associative_scan( + combine_fn: CombineFunction, + input_tensor: torch.Tensor, + dim: int, + reverse: bool = False, +) -> torch.Tensor: ... + + +@overload +@_decorators.device_func_replacement(higher_order_ops.associative_scan) +@_decorators.api(is_device_only=True) +def associative_scan( + combine_fn: CombineFunction, + input_tensor: tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, +) -> tuple[torch.Tensor, ...]: ... + + +@_decorators.device_func_replacement(higher_order_ops.associative_scan) +@_decorators.api(is_device_only=True) +def associative_scan( + combine_fn: CombineFunction, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Applies an associative scan operation along a specified dimension. + + Args: + combine_fn: A binary function that combines two elements element-wise. + Can be tensor->tensor or tuple->tuple function. + input_tensor: Input tensor or tuple of tensors to scan. + dim: The dimension along which to scan. + reverse: If True, performs the scan in reverse order. + + Returns: + A tensor or tuple of tensors with the same shape as input containing the scan result. 
+ """ + raise exc.NotInsideKernel + + +@_decorators.register_fake(associative_scan) +def _( + combine_fn: CombineFunction, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Fake implementation that returns fake tensors with the same shape as input.""" + if isinstance(input_tensor, (tuple, list)): + return tuple(torch.empty_like(t) for t in input_tensor) + return torch.empty_like(input_tensor) + + +@_decorators.register_to_device_ir(associative_scan) +def _( + tracer: proxy_tensor.PythonKeyTracer, + combine_fn: CombineFunction, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + """ + Device IR implementation that handles tracing for associative_scan. We map + associative_scan to _associative_scan, whith a pre-traced graph for the combine + function. + """ + from .._compiler.device_ir import DeviceIR + from .._compiler.device_ir import HelperFunctionGraphInfo + from .._compiler.device_ir import args_to_proxies + from .._compiler.device_ir import select_decomp_table + from .._compiler.helper_function import create_combine_function_wrapper + + is_tuple_input = isinstance(input_tensor, (tuple, list)) + if is_tuple_input: + assert all(isinstance(t, torch.Tensor) for t in input_tensor), ( + "associative_scan input must be a tuple of tensors" + ) + else: + assert isinstance(input_tensor, torch.Tensor), ( + "associative_scan input must be a tensor" + ) + assert isinstance(dim, int), "associative_scan dim must be an integer" + + assert callable(combine_fn), "combine_fn must be callable" + combine_fn = create_combine_function_wrapper( + combine_fn, is_tuple_input=is_tuple_input, target_format="unpacked" + ) + + # Create fake inputs for the combine function + fake_inputs = [] + for tensor in input_tensor if is_tuple_input else [input_tensor]: + fake_inputs.extend( + [ + torch.empty([1], dtype=tensor.dtype, device=tensor.device), + torch.empty([1], dtype=tensor.dtype, device=tensor.device), + ] + ) + + combine_graph = proxy_tensor.make_fx( + combine_fn, decomposition_table=select_decomp_table() + )(*fake_inputs).graph + combine_graph_id = DeviceIR.current().add_graph( + combine_graph, + HelperFunctionGraphInfo, + node_args=[], + ) + + # Create the associative_scan tracing operation + scan_args = (combine_graph_id, input_tensor, dim, reverse, is_tuple_input) + proxy_args, proxy_kwargs = args_to_proxies(tracer, scan_args) + proxy_out = tracer.create_proxy( + "call_function", + _associative_scan, + proxy_args, + proxy_kwargs, + ) + + # The output has the same shape as the input + if is_tuple_input: + proxy_tensor.track_tensor_tree( + input_tensor, proxy_out, constant=None, tracer=tracer + ) + tuple_proxies = [] + assert isinstance(input_tensor, (tuple, list)) + for i, tensor in enumerate(input_tensor): + element_proxy = tracer.create_proxy( + "call_function", + operator.getitem, + (proxy_out, i), + {}, + ) + proxy_tensor.track_tensor_tree( + tensor, element_proxy, constant=None, tracer=tracer + ) + tuple_proxies.append(tensor) + return tuple(tuple_proxies) + + proxy_tensor.track_tensor_tree( + input_tensor, proxy_out, constant=None, tracer=tracer + ) + return input_tensor + + +@_decorators.type_propagation(associative_scan) +def _( + combine_fn: TypeInfo, + input_tensor: TypeInfo, + dim: TypeInfo, + reverse: TypeInfo | None = None, + *, + origin: Origin, +) -> TypeInfo: + """Type propagation for associative_scan - output 
has same type as input.""" + from .._compiler.type_propagation import CallableType + from .._compiler.type_propagation import SequenceType + from .._compiler.type_propagation import TensorType + + # Validate that combine_fn is callable + if not isinstance(combine_fn, CallableType): + raise exc.TypeInferenceError(f"combine_fn must be callable, got {combine_fn}") + + # Validate that input_tensor is a tensor or tuple of tensors + if isinstance(input_tensor, TensorType): + # Single tensor case + return input_tensor + if isinstance(input_tensor, SequenceType): + # Tuple of tensors case - validate all elements are tensors + for elem_type in input_tensor.unpack(): + if not isinstance(elem_type, TensorType): + raise exc.TypeInferenceError( + f"All elements in tuple must be tensors, got {elem_type}" + ) + # Return the same tuple type + return input_tensor + raise exc.TypeInferenceError( + f"input_tensor must be a tensor or tuple of tensors, got {input_tensor}" + ) + + +@_decorators.device_func_replacement(torch.cumsum) +def cumsum(input_tensor: torch.Tensor, dim: int, reverse: bool = False) -> torch.Tensor: + """ + Compute the cumulative sum along a specified dimension. + + Args: + input_tensor: Input tensor to compute cumulative sum. + dim: The dimension along which to compute cumulative sum. + reverse: If True, performs the cumsum in reverse order. + + Returns: + A tensor with the same shape as input containing the cumulative sum. + """ + return associative_scan(torch.add, input_tensor, dim, reverse) + + +@_decorators.device_func_replacement(torch.cumprod) +def cumprod( + input_tensor: torch.Tensor, dim: int, reverse: bool = False +) -> torch.Tensor: + """ + Compute the cumulative product along a specified dimension. + + Args: + input_tensor: Input tensor to compute cumulative product. + dim: The dimension along which to compute cumulative product. + reverse: If True, performs the cumprod in reverse order. + + Returns: + A tensor with the same shape as input containing the cumulative product. 
+ """ + return associative_scan(torch.mul, input_tensor, dim, reverse) + + +@_decorators.api() +def _associative_scan( + combine_graph_id: int, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, + is_tuple_input: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Device IR implementation of associative scan, not meant to be called directly.""" + raise AssertionError("this should never be called") + + +@_decorators.register_fake(_associative_scan) +def _( + combine_graph_id: int, + input_tensor: torch.Tensor | tuple[torch.Tensor, ...], + dim: int, + reverse: bool = False, + is_tuple_input: bool = False, +) -> torch.Tensor | tuple[torch.Tensor, ...]: + """Fake implementation that returns a tensor/tuple with the same shape as input.""" + if is_tuple_input: + assert isinstance(input_tensor, (tuple, list)), input_tensor + return tuple(torch.empty_like(t) for t in input_tensor) + assert isinstance(input_tensor, torch.Tensor), input_tensor + return torch.empty_like(input_tensor) + + +@_decorators.codegen(_associative_scan) +def _(state: CodegenState) -> ast.AST | list[ast.AST]: + """Generate code for associative scan with combine function.""" + + combine_graph_id = state.proxy_arg(0) + dim = state.proxy_arg(2) + reverse = state.proxy_arg(3) + is_tuple_input = state.proxy_arg(4) + + input_tensor = _get_input_tensor_ast(state, bool(is_tuple_input)) + helper_func_name = _register_helper_function(state, cast("int", combine_graph_id)) + scan_expr = _create_scan_expression( + input_tensor, cast("int", dim), helper_func_name, bool(reverse) + ) + + if is_tuple_input: + return _create_tuple_result_expressions(state, scan_expr) + return scan_expr + + +def _get_input_tensor_ast(state: CodegenState, is_tuple_input: bool) -> ast.AST: + """Get the input tensor AST, handling tuple inputs specially.""" + if not is_tuple_input: + return state.ast_arg(1) + + raw_input = state.ast_args[1] + if isinstance(raw_input, tuple): + from .._compiler.ast_extension import create + + tuple_elts = [ + elt if isinstance(elt, ast.AST) else ast.Constant(value=elt) + for elt in raw_input + ] + return create(ast.Tuple, elts=tuple_elts, ctx=ast.Load()) + return state.ast_arg(1) + + +def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str: + """Register the helper function and return its name.""" + from .._compiler.host_function import HostFunction + + helper_graph_info = HostFunction.current().device_ir.graphs[combine_graph_id] + state.codegen.device_function.register_helper_function(helper_graph_info) + return helper_graph_info.name + + +def _create_scan_expression( + input_tensor: ast.AST, dim: int, helper_func_name: str, reverse: bool +) -> ast.AST: + """Create the tl.associative_scan expression.""" + from .._compiler.ast_extension import expr_from_string + + template = ( + f"tl.associative_scan(input_tensor, dim_value, {helper_func_name}, reverse=True)" + if reverse + else f"tl.associative_scan(input_tensor, dim_value, {helper_func_name})" + ) + return expr_from_string( + template, + input_tensor=input_tensor, + dim_value=ast.Constant(value=dim), + ) + + +def _create_tuple_result_expressions( + state: CodegenState, scan_expr: ast.AST +) -> list[ast.AST]: + """Create getitem expressions for tuple results.""" + from .._compiler.ast_extension import expr_from_string + + raw_input = state.ast_args[1] + num_elements = len(raw_input) if isinstance(raw_input, tuple) else 2 + + return [ + expr_from_string(f"scan_result[{i}]", scan_result=scan_expr) + for 
i in range(num_elements) + ] diff --git a/test/test_associative_scan.py b/test/test_associative_scan.py new file mode 100644 index 00000000..e4b16bb7 --- /dev/null +++ b/test/test_associative_scan.py @@ -0,0 +1,862 @@ +from __future__ import annotations + +import unittest + +import torch + +import helion +from helion._testing import DEVICE +from helion._testing import code_and_output +import helion.language as hl + + +def add_combine_fn(x, y): + """Simple addition combine function for prefix sum.""" + return x + y + + +def max_combine_fn(x, y): + """Maximum combine function for prefix maximum.""" + return torch.maximum(x, y) + + +def mul_combine_fn(x, y): + """Multiplication combine function for prefix product.""" + return x * y + + +def min_combine_fn(x, y): + """Minimum combine function for prefix minimum.""" + return torch.minimum(x, y) + + +def helion_combine_fn(left_values, left_indices, right_values, right_indices): + """Tuple combine function with unpacked arguments (matching GitHub issue example).""" + # Segmented scan: if indices are the same, add values; otherwise, take right values (reset) + same_segment = left_indices == right_indices + combined_values = torch.where( + same_segment, left_values + right_values, right_values + ) + combined_indices = right_indices # Always propagate the right index + return combined_values, combined_indices + + +def segmented_combine_fn(left_values, left_indices, right_values, right_indices): + """Segmented scan: reset accumulation when segment changes.""" + same_segment = left_indices == right_indices + combined_values = torch.where( + same_segment, left_values + right_values, right_values + ) + combined_indices = right_indices # Always propagate the right index + return combined_values, combined_indices + + +def argmax_combine_fn(left_values, left_indices, right_values, right_indices): + """Cumulative argmax: keep the value and index of the maximum element seen so far.""" + # If right value is greater, take right value and index; otherwise keep left + take_right = right_values > left_values + combined_values = torch.where(take_right, right_values, left_values) + combined_indices = torch.where(take_right, right_indices, left_indices) + return combined_values, combined_indices + + +def cumsum_helper(x: torch.Tensor) -> torch.Tensor: + """Helper function that performs cumulative sum using hl.associative_scan.""" + return hl.associative_scan(add_combine_fn, x, dim=0) + + +@helion.jit +def jit_add_combine_fn(x, y): + """Addition combine function with @helion.jit decorator (should be ignored).""" + return x + y + + +class TestAssociativeScan(unittest.TestCase): + def test_associative_scan_basic_addition(self): + """Test basic associative_scan functionality with prefix sum.""" + + @helion.kernel(use_default_config=True) + def test_scan_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) + return result + + # Create test input + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]], + device=DEVICE, + ) + + # Test that the kernel compiles and runs + code, result = code_and_output(test_scan_kernel, (x,)) + + # Test the actual scan operation + expected = torch.tensor( + [[1.0, 3.0, 6.0, 10.0], [5.0, 11.0, 18.0, 26.0], [9.0, 19.0, 30.0, 42.0]], + device=DEVICE, + ) + torch.testing.assert_close(result, expected) + + # Verify the generated code contains the correct helper function + 
self.assertIn("def helper_function_", code) + self.assertIn("param_0 + param_1", code) + self.assertIn("tl.associative_scan", code) + + def test_associative_scan_maximum(self): + """Test associative_scan with maximum combine function.""" + + @helion.kernel(use_default_config=True) + def test_max_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(max_combine_fn, row_data, dim=1) + return result + + # Test input with decreasing and increasing values + x = torch.tensor( + [[1.0, 5.0, 2.0, 8.0, 3.0], [7.0, 1.0, 9.0, 2.0, 4.0]], + device=DEVICE, + ) + + code, result = code_and_output(test_max_kernel, (x,)) + + # Expected prefix maximum + expected = torch.tensor( + [[1.0, 5.0, 5.0, 8.0, 8.0], [7.0, 7.0, 9.0, 9.0, 9.0]], + device=DEVICE, + ) + torch.testing.assert_close(result, expected) + + # Verify the generated code contains maximum operation (either tl.maximum or triton_helpers.maximum) + self.assertTrue("tl.maximum" in code or "triton_helpers.maximum" in code) + + def test_associative_scan_multiplication(self): + """Test associative_scan with multiplication combine function.""" + + @helion.kernel(use_default_config=True) + def test_mul_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(mul_combine_fn, row_data, dim=1) + return result + + # Test input for prefix product + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [2.0, 0.5, 3.0, 2.0]], + device=DEVICE, + ) + + code, result = code_and_output(test_mul_kernel, (x,)) + + # Expected prefix product + expected = torch.tensor( + [[1.0, 2.0, 6.0, 24.0], [2.0, 1.0, 3.0, 6.0]], + device=DEVICE, + ) + torch.testing.assert_close(result, expected) + + # Verify the generated code contains multiplication + self.assertIn("param_0 * param_1", code) + + def test_associative_scan_minimum(self): + """Test associative_scan with minimum combine function.""" + + @helion.kernel(use_default_config=True) + def test_min_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(min_combine_fn, row_data, dim=1) + return result + + # Test input with various values + x = torch.tensor( + [[5.0, 2.0, 8.0, 1.0, 6.0], [3.0, 7.0, 1.0, 9.0, 2.0]], + device=DEVICE, + ) + + code, result = code_and_output(test_min_kernel, (x,)) + + # Expected prefix minimum + expected = torch.tensor( + [[5.0, 2.0, 2.0, 1.0, 1.0], [3.0, 3.0, 1.0, 1.0, 1.0]], + device=DEVICE, + ) + torch.testing.assert_close(result, expected) + + # Verify the generated code contains minimum operation (either tl.minimum or triton_helpers.minimum) + self.assertTrue("tl.minimum" in code or "triton_helpers.minimum" in code) + + def test_associative_scan_multiple_functions(self): + """Test using multiple different combine functions in one kernel.""" + + @helion.kernel(use_default_config=True) + def test_multi_kernel(x: torch.Tensor) -> torch.Tensor: + sum_result = torch.empty_like(x) + max_result = torch.empty_like(x) + + for i in hl.tile(x.size(0)): + row_data = x[i, :] + # Prefix sum + sum_result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) + # Prefix maximum + max_result[i, :] = hl.associative_scan(max_combine_fn, row_data, dim=1) + + # Return sum for testing + return sum_result + + x = torch.tensor([[1.0, 3.0, 2.0, 4.0]], device=DEVICE) + + code, result = 
code_and_output(test_multi_kernel, (x,)) + + # Test the sum result + expected_sum = torch.tensor([[1.0, 4.0, 6.0, 10.0]], device=DEVICE) + torch.testing.assert_close(result, expected_sum) + + # Verify multiple helper functions are generated + self.assertIn("helper_function_0", code) + self.assertIn("helper_function_1", code) + self.assertIn("param_0 + param_1", code) + # Check for maximum operation (either format) + self.assertTrue("tl.maximum" in code or "triton_helpers.maximum" in code) + + def test_associative_scan_type_propagation(self): + """Test that associative_scan type propagation works correctly.""" + + @helion.kernel(use_default_config=True) + def test_type_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) + return result + + x = torch.randn(16, 1024, device=DEVICE, dtype=torch.float32) + code, result = code_and_output(test_type_kernel, (x,)) + + # Verify the output has the same type and shape as input + self.assertEqual(result.dtype, x.dtype) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.device, x.device) + # Verify it produces the correct cumulative sum + expected = torch.cumsum(x, dim=1) + # Use relaxed tolerance for large tensors due to accumulated floating-point errors + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + def test_associative_scan_different_dtypes(self): + """Test associative_scan with different data types.""" + + for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]: + with self.subTest(dtype=dtype): + + @helion.kernel(use_default_config=True) + def test_dtype_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan( + add_combine_fn, row_data, dim=1 + ) + return result + + # Use integer values for all dtypes to avoid precision issues + x_vals = [[1, 2, 3, 4], [5, 6, 7, 8]] + x = torch.tensor(x_vals, device=DEVICE, dtype=dtype) + + code, result = code_and_output(test_dtype_kernel, (x,)) + + # Verify output dtype matches input + self.assertEqual(result.dtype, x.dtype) + + # Check correctness for numeric types + if dtype in [torch.float32, torch.float64, torch.int32, torch.int64]: + expected = torch.cumsum(x, dim=1) + # Convert expected to match result dtype if needed + if expected.dtype != result.dtype: + expected = expected.to(result.dtype) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + def test_associative_scan_different_sizes(self): + """Test associative_scan with different tensor sizes.""" + + test_shapes = [ + (1, 4), # Single row + (3, 8), # Multiple rows + (5, 16), # Medium size + (2, 1), # Single column + (4, 1024), # Large size + (8, 1024), # Multiple large rows + ] + + for shape in test_shapes: + with self.subTest(shape=shape): + + @helion.kernel(use_default_config=True) + def test_size_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan( + add_combine_fn, row_data, dim=1 + ) + return result + + x = torch.randn(shape, device=DEVICE) + code, result = code_and_output(test_size_kernel, (x,)) + + # Verify output shape matches input + self.assertEqual(result.shape, x.shape) + + # Verify correctness + expected = torch.cumsum(x, dim=1) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + def 
test_associative_scan_reverse(self): + """Test associative_scan with reverse=True parameter.""" + + @helion.kernel(use_default_config=True) + def test_reverse_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan( + add_combine_fn, row_data, dim=1, reverse=True + ) + return result + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], device=DEVICE) + + code, result = code_and_output(test_reverse_kernel, (x,)) + + # For reverse prefix sum: [10, 9, 7, 4] (sum from right to left) + expected = torch.tensor([[10.0, 9.0, 7.0, 4.0]], device=DEVICE) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + # Verify reverse parameter is in generated code + self.assertIn("reverse=True", code) + + def test_associative_scan_edge_cases(self): + """Test associative_scan edge cases.""" + + # Single element + @helion.kernel(use_default_config=True) + def test_single_element(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) + return result + + x_single = torch.tensor([[5.0]], device=DEVICE) + code, result = code_and_output(test_single_element, (x_single,)) + expected = torch.tensor([[5.0]], device=DEVICE) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + # Two elements + x_two = torch.tensor([[3.0, 7.0]], device=DEVICE) + code, result = code_and_output(test_single_element, (x_two,)) + expected = torch.tensor([[3.0, 10.0]], device=DEVICE) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + def test_associative_scan_large_scale(self): + """Test associative_scan with large tensors for performance validation.""" + + @helion.kernel(use_default_config=True) + def test_large_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) + return result + + # Test with large tensor + x = torch.randn(32, 1024, device=DEVICE) + code, result = code_and_output(test_large_kernel, (x,)) + + # Verify correctness on large scale + expected = torch.cumsum(x, dim=1) + # Use relaxed tolerance for large tensors due to accumulated floating-point errors + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + # Verify output properties + self.assertEqual(result.shape, x.shape) + self.assertEqual(result.dtype, x.dtype) + + def test_associative_scan_torch_hops_mapping(self): + """Test that torch._higher_order_ops.associative_scan automatically maps to hl.associative_scan.""" + + @helion.kernel(use_default_config=True) + def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + # Use torch._higher_order_ops.associative_scan directly + result[i, :] = torch._higher_order_ops.associative_scan( + add_combine_fn, row_data, dim=1 + ) + return result + + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], + device=DEVICE, + ) + + # Test that the kernel compiles and runs correctly + code, result = code_and_output(test_torch_hops_kernel, (x,)) + + # Expected prefix sum results + expected = torch.tensor( + [[1.0, 3.0, 6.0, 10.0], [5.0, 11.0, 18.0, 26.0]], + device=DEVICE, + ) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + # Verify the generated code contains the proper combine 
function and associative scan + self.assertIn("def helper_function_", code) + self.assertIn("tl.associative_scan", code) + self.assertIn("param_0 + param_1", code) + + def test_associative_scan_code_generation(self): + """Test that the generated code structure is correct.""" + + @helion.kernel(use_default_config=True) + def test_codegen_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(add_combine_fn, row_data, dim=1) + return result + + x = torch.tensor([[1.0, 2.0, 3.0]], device=DEVICE) + code, result = code_and_output(test_codegen_kernel, (x,)) + + # Check essential code structure + self.assertIn("@triton.jit", code) + self.assertIn("def helper_function_", code) + self.assertIn("tl.associative_scan", code) + self.assertIn("return", code) + + # Verify no placeholders remain + self.assertNotIn("TODO", code) + self.assertNotIn("placeholder", code) + + def test_associative_scan_jit_decorator_ignored(self): + """Test that @helion.jit decorator on combine functions is ignored.""" + + @helion.kernel(use_default_config=True) + def test_jit_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.associative_scan(jit_add_combine_fn, row_data, dim=1) + return result + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], device=DEVICE) + code, result = code_and_output(test_jit_kernel, (x,)) + + # Expected prefix sum results + expected = torch.tensor([[1.0, 3.0, 6.0, 10.0]], device=DEVICE) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + # Verify the generated code contains the proper combine function and associative scan + self.assertIn("def helper_function_", code) + self.assertIn("tl.associative_scan", code) + self.assertIn("param_0 + param_1", code) + # Verify @helion.jit decorator doesn't appear in generated code + self.assertNotIn("@helion.jit", code) + + def test_associative_scan_tuple_args(self): + """Test associative_scan with tuple arguments (matching GitHub issue #237 pattern).""" + + @helion.kernel(use_default_config=True) + def test_segmented_kernel( + indices: torch.Tensor, input_data: torch.Tensor + ) -> torch.Tensor: + E, C = input_data.shape + output = torch.zeros( + (E, C), dtype=input_data.dtype, device=input_data.device + ) + + for tile_e, tile_f in hl.tile([E, C]): + vals = input_data[tile_e, tile_f] + # Broadcast indices to match vals shape for the scan + idxs = indices[tile_e].unsqueeze(1).expand_as(vals) + + # Create tuple inside the device loop (as per GitHub issue example) + input_tuple = (vals, idxs) + + # Use torch._higher_order_ops.associative_scan as in the example + out_vals, out_idxs = torch._higher_order_ops.associative_scan( + helion_combine_fn, input_tuple, 0 + ) + + output[tile_e, tile_f] = out_vals + + return output + + # Create test data + E, C = 4, 2 + indices = torch.tensor( + [0.0, 0.0, 1.0, 1.0], device=DEVICE + ) # Use float to match input_data + input_data = torch.ones((E, C), device=DEVICE) + + code, result = code_and_output(test_segmented_kernel, (indices, input_data)) + + # Expected: cumulative sum for each position + expected = torch.tensor( + [[1.0, 1.0], [2.0, 2.0], [1.0, 1.0], [2.0, 2.0]], device=DEVICE + ) + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + # Verify the generated code structure + self.assertIn("def helper_function_", code) + self.assertIn("tl.associative_scan", code) + + def 
test_associative_scan_segmented_reduction(self): + """Test associative_scan for segmented reduction use case.""" + + @helion.kernel(use_default_config=True) + def segmented_scan_kernel( + indices: torch.Tensor, input_data: torch.Tensor + ) -> torch.Tensor: + E, C = input_data.shape + output = torch.zeros( + (E, C), dtype=input_data.dtype, device=input_data.device + ) + + for tile_e, tile_f in hl.tile([E, C]): + vals = input_data[tile_e, tile_f] + # Convert indices to float to match vals dtype and broadcast to match shape + idxs = indices[tile_e].float().unsqueeze(1).expand_as(vals) + + # Use tuple argument functionality for segmented scan + out_vals, _ = torch._higher_order_ops.associative_scan( + segmented_combine_fn, (vals, idxs), 0 + ) + + output[tile_e, tile_f] = out_vals + + return output + + # Test segmented reduction + E, C = 6, 3 + # Segments: [0,0], [1,1,1], [2] - three segments of sizes 2, 3, 1 + indices = torch.tensor([0, 0, 1, 1, 1, 2], device=DEVICE) + input_data = torch.ones((E, C), device=DEVICE) + + code, result = code_and_output(segmented_scan_kernel, (indices, input_data)) + + # Expected: cumulative sum within each segment + expected = torch.tensor( + [ + [1.0, 1.0, 1.0], # segment 0, position 0 + [2.0, 2.0, 2.0], # segment 0, position 1 + [1.0, 1.0, 1.0], # segment 1, position 0 + [2.0, 2.0, 2.0], # segment 1, position 1 + [3.0, 3.0, 3.0], # segment 1, position 2 + [1.0, 1.0, 1.0], # segment 2, position 0 + ], + device=DEVICE, + ) + + torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) + + # Verify the generated code structure + self.assertIn("def helper_function_", code) + self.assertIn("tl.associative_scan", code) + + def test_associative_scan_cumulative_argmax(self): + """Test cumulative argmax using tuple args with (float, int) types.""" + + @helion.kernel(use_default_config=True) + def cumulative_argmax_kernel( + input_data: torch.Tensor, positions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + max_values = torch.zeros_like(input_data) + max_indices = torch.zeros_like(input_data, dtype=torch.int32) + for tile_e in hl.tile(input_data.size(0)): + vals = input_data[tile_e, :] + # Convert positions to float to match vals dtype, then broadcast to match vals shape + indices = positions[:].to(torch.float32).unsqueeze(0).expand_as(vals) + + # Use hl.associative_scan directly with tuple args - return both values and indices + out_vals, out_indices = hl.associative_scan( + argmax_combine_fn, (vals, indices), dim=1 + ) + + max_values[tile_e, :] = out_vals + max_indices[tile_e, :] = out_indices.to(torch.int32) + + return max_values, max_indices + + input_data = torch.tensor( + [ + [1.0, 5.5, 2.0], + [3.0, 2.0, 4.0], + [2.0, 7.0, 1.0], + [4.1, 1.0, 3.0], + ], + device=DEVICE, + ) + positions = torch.tensor([0, 1, 2], device=DEVICE, dtype=torch.int32) + code, (result_values, result_indices) = code_and_output( + cumulative_argmax_kernel, (input_data, positions) + ) + + # Expected cumulative maximum values + expected_values = torch.tensor( + [ + [1.0, 5.5, 5.5], + [3.0, 3.0, 4.0], + [2.0, 7.0, 7.0], + [4.1, 4.1, 4.1], + ], + device=DEVICE, + ) + + # Expected indices of the maximum values (which row they came from) + expected_indices = torch.tensor( + [ + [0, 1, 1], + [0, 0, 2], + [0, 1, 1], + [0, 0, 0], + ], + device=DEVICE, + dtype=torch.int32, + ) + + torch.testing.assert_close(result_values, expected_values) + torch.testing.assert_close(result_indices, expected_indices) + + # Verify the generated code structure + self.assertIn("def helper_function_", 
code) + self.assertIn("tl.associative_scan", code) + + def test_associative_scan_in_helper_function(self): + """Test calling a function that internally uses hl.associative_scan.""" + + @helion.kernel(use_default_config=True) + def test_helper_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + # Use the cumsum_helper function which internally calls hl.associative_scan + result[i, :] = cumsum_helper(x[i, :]) + return result + + # Create test input + x = torch.tensor( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], + device=DEVICE, + ) + + # Test that the kernel compiles and runs + code, result = code_and_output(test_helper_kernel, (x,)) + + # Verify that the kernel runs successfully and produces output + self.assertEqual(result.shape, x.shape) + + # Verify that the helper function was used (output should be different from input) + self.assertFalse(torch.equal(result, x)) + + # Verify the generated code contains the helper function and associative scan + self.assertIn("def helper_function_", code) + self.assertIn("tl.associative_scan", code) + self.assertIn("param_0 + param_1", code) + + def test_cumsum_basic(self): + """Test basic cumsum functionality.""" + + @helion.kernel(use_default_config=True) + def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = torch.cumsum(row_data, dim=1) + return result + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], device=DEVICE) + + code, result = code_and_output(test_cumsum_kernel, (x,)) + + # Expected cumulative sum + expected = torch.tensor( + [[1.0, 3.0, 6.0, 10.0], [5.0, 11.0, 18.0, 26.0]], device=DEVICE + ) + torch.testing.assert_close(result, expected) + + # Verify the generated code contains cumsum implementation + self.assertIn("def helper_function_", code) + self.assertIn("param_0 + param_1", code) + self.assertIn("tl.associative_scan", code) + + def test_cumsum_reverse(self): + """Test cumsum with reverse=True.""" + + @helion.kernel(use_default_config=True) + def test_cumsum_reverse_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.cumsum(row_data, dim=1, reverse=True) + return result + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], device=DEVICE) + + code, result = code_and_output(test_cumsum_reverse_kernel, (x,)) + + # For reverse cumsum: [10, 9, 7, 4] (sum from right to left) + expected = torch.tensor([[10.0, 9.0, 7.0, 4.0]], device=DEVICE) + torch.testing.assert_close(result, expected) + + # Verify reverse parameter is used + self.assertIn("reverse=True", code) + + def test_cumsum_different_dtypes(self): + """Test cumsum with different data types.""" + + for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]: + with self.subTest(dtype=dtype): + + @helion.kernel(use_default_config=True) + def test_cumsum_dtype_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = torch.cumsum(row_data, dim=1) + return result + + x = torch.tensor( + [[1, 2, 3, 4], [5, 6, 7, 8]], device=DEVICE, dtype=dtype + ) + + code, result = code_and_output(test_cumsum_dtype_kernel, (x,)) + + # Verify output dtype matches input + self.assertEqual(result.dtype, x.dtype) + + # Check correctness + expected = torch.cumsum(x, dim=1) + # Convert expected to match result dtype if needed + if expected.dtype != result.dtype: 
+ expected = expected.to(result.dtype) + torch.testing.assert_close(result, expected) + + def test_cumprod_basic(self): + """Test basic cumprod functionality.""" + + @helion.kernel(use_default_config=True) + def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = torch.cumprod(row_data, dim=1) + return result + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [2.0, 0.5, 3.0, 2.0]], device=DEVICE) + + code, result = code_and_output(test_cumprod_kernel, (x,)) + + # Expected cumulative product + expected = torch.tensor( + [[1.0, 2.0, 6.0, 24.0], [2.0, 1.0, 3.0, 6.0]], device=DEVICE + ) + torch.testing.assert_close(result, expected) + + # Verify the generated code contains cumprod implementation + self.assertIn("def helper_function_", code) + self.assertIn("param_0 * param_1", code) + self.assertIn("tl.associative_scan", code) + + def test_cumprod_reverse(self): + """Test cumprod with reverse=True.""" + + @helion.kernel(use_default_config=True) + def test_cumprod_reverse_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.cumprod(row_data, dim=1, reverse=True) + return result + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], device=DEVICE) + + code, result = code_and_output(test_cumprod_reverse_kernel, (x,)) + + # For reverse cumprod: [24, 24, 12, 4] (product from right to left) + expected = torch.tensor([[24.0, 24.0, 12.0, 4.0]], device=DEVICE) + torch.testing.assert_close(result, expected) + + # Verify reverse parameter is used + self.assertIn("reverse=True", code) + + def test_cumprod_different_dtypes(self): + """Test cumprod with different data types.""" + + for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]: + with self.subTest(dtype=dtype): + + @helion.kernel(use_default_config=True) + def test_cumprod_dtype_kernel(x: torch.Tensor) -> torch.Tensor: + result = torch.empty_like(x) + for i in hl.tile(x.size(0)): + row_data = x[i, :] + result[i, :] = hl.cumprod(row_data, dim=1) + return result + + x = torch.tensor( + [[1, 2, 3, 2], [2, 1, 2, 2]], device=DEVICE, dtype=dtype + ) + + code, result = code_and_output(test_cumprod_dtype_kernel, (x,)) + + # Verify output dtype matches input + self.assertEqual(result.dtype, x.dtype) + + # Check correctness + expected = torch.cumprod(x, dim=1) + # Convert expected to match result dtype if needed + if expected.dtype != result.dtype: + expected = expected.to(result.dtype) + torch.testing.assert_close(result, expected) + + def test_cumsum_cumprod_mixed(self): + """Test using both cumsum and cumprod in the same kernel.""" + + @helion.kernel(use_default_config=True) + def test_mixed_kernel(x: torch.Tensor) -> torch.Tensor: + sum_result = torch.empty_like(x) + prod_result = torch.empty_like(x) + + for i in hl.tile(x.size(0)): + row_data = x[i, :] + # Cumulative sum + sum_result[i, :] = torch.cumsum(row_data, dim=1) + # Cumulative product + prod_result[i, :] = torch.cumprod(row_data, dim=1) + + # Return sum for testing + return sum_result + + x = torch.tensor([[1.0, 2.0, 3.0, 4.0]], device=DEVICE) + + code, result = code_and_output(test_mixed_kernel, (x,)) + + # Test the sum result + expected_sum = torch.tensor([[1.0, 3.0, 6.0, 10.0]], device=DEVICE) + torch.testing.assert_close(result, expected_sum) + + # Verify both helper functions are generated + self.assertIn("helper_function_0", code) + self.assertIn("helper_function_1", code) + self.assertIn("param_0 + 
param_1", code) + self.assertIn("param_0 * param_1", code) + + +if __name__ == "__main__": + unittest.main()