|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from abc import ABC |
| 4 | +from abc import abstractmethod |
| 5 | +import ast |
| 6 | +from typing import TYPE_CHECKING |
| 7 | +from typing import cast |
| 8 | + |
| 9 | +from .ast_extension import create |
| 10 | +from .ast_extension import create_arg |
| 11 | +from .ast_extension import create_arguments |
| 12 | +from .ast_extension import expr_from_string |
| 13 | +from .ast_extension import statement_from_string |
| 14 | + |
| 15 | +if TYPE_CHECKING: |
| 16 | + import types |
| 17 | + |
| 18 | + from .device_function import DeviceFunction |
| 19 | + from .device_ir import HelperFunctionGraphInfo |
| 20 | + |
| 21 | + |
| 22 | +class CodegenInterface(ABC): |
| 23 | + """Abstract base class for codegen interfaces used by GraphInterpreter.""" |
| 24 | + |
| 25 | + def __init__(self, device_function: DeviceFunction) -> None: |
| 26 | + self.device_function = device_function |
| 27 | + |
| 28 | + @abstractmethod |
| 29 | + def add_statement(self, stmt: ast.AST | str | None) -> None: |
| 30 | + """Add a statement to the generated code.""" |
| 31 | + |
| 32 | + def tmpvar(self, *, dce: bool = False, prefix: str = "v") -> str: |
| 33 | + """Generate a temporary variable name.""" |
| 34 | + return self.device_function.unique_name(prefix, dce=dce) |
| 35 | + |
| 36 | + def lift(self, expr: ast.AST, *, dce: bool = False, prefix: str = "v") -> ast.Name: |
| 37 | + """Lift an expression to a temporary variable if needed.""" |
| 38 | + if isinstance(expr, ast.Name): |
| 39 | + return expr |
| 40 | + varname = self.tmpvar(dce=dce, prefix=prefix) |
| 41 | + self.add_statement(statement_from_string(f"{varname} = expr", expr=expr)) |
| 42 | + return create(ast.Name, id=varname, ctx=ast.Load()) |
| 43 | + |
| 44 | + |
| 45 | +def extract_helper_function(helper_fn: object) -> types.FunctionType: |
| 46 | + """Extract the actual function from a Kernel object or return as-is. |
| 47 | +
|
| 48 | + This utility function centralizes the logic for handling both regular functions |
| 49 | + and Kernel objects that wrap functions. |
| 50 | + """ |
| 51 | + from ..runtime.kernel import Kernel |
| 52 | + |
| 53 | + # pyre-ignore[16]: We check isinstance before accessing .fn |
| 54 | + return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn |
| 55 | + |
| 56 | + |
| 57 | +class HelperCodegen(CodegenInterface): |
| 58 | + """Codegen wrapper for helper function generation.""" |
| 59 | + |
| 60 | + def __init__(self, device_function: DeviceFunction) -> None: |
| 61 | + super().__init__(device_function) |
| 62 | + |
| 63 | + def add_statement(self, stmt: ast.AST | str | None) -> None: |
| 64 | + if stmt is not None: |
| 65 | + if isinstance(stmt, str): |
| 66 | + stmt = statement_from_string(stmt) |
| 67 | + self.device_function.body.append(stmt) |
| 68 | + |
| 69 | + |
| 70 | +class HelperFunctionManager: |
| 71 | + """Manages helper function registration and code generation.""" |
| 72 | + |
| 73 | + def __init__(self) -> None: |
| 74 | + self.helper_functions: dict[str, HelperFunctionGraphInfo] = {} |
| 75 | + |
| 76 | + def register_helper_function( |
| 77 | + self, helper_graph_info: HelperFunctionGraphInfo |
| 78 | + ) -> None: |
| 79 | + """Register a helper function to be generated at global scope.""" |
| 80 | + self.helper_functions[helper_graph_info.name] = helper_graph_info |
| 81 | + |
| 82 | + def codegen_helper_functions(self) -> list[ast.stmt]: |
| 83 | + """Generate helper function definitions at global scope.""" |
| 84 | + helper_defs = [] |
| 85 | + for helper_graph_info in self.helper_functions.values(): |
| 86 | + # Determine the number of parameters from the graph |
| 87 | + input_nodes = helper_graph_info.find_input_nodes() |
| 88 | + |
| 89 | + # Generate argument list with consistent names |
| 90 | + args = [] |
| 91 | + param_names = [] |
| 92 | + for i in range(len(input_nodes)): |
| 93 | + arg_name = f"param_{i}" |
| 94 | + args.append(create_arg(arg_name)) |
| 95 | + param_names.append(arg_name) |
| 96 | + |
| 97 | + # Store parameter names for use in body generation |
| 98 | + helper_graph_info._param_names = param_names |
| 99 | + |
| 100 | + # Process the FX graph to generate the correct helper function body |
| 101 | + func_body = self._codegen_helper_function_body(helper_graph_info) |
| 102 | + |
| 103 | + # Generate the function structure with @triton.jit decorator |
| 104 | + func_def = create( |
| 105 | + ast.FunctionDef, |
| 106 | + name=helper_graph_info.name, |
| 107 | + args=create_arguments(args), |
| 108 | + body=func_body, |
| 109 | + decorator_list=[expr_from_string("triton.jit")], |
| 110 | + type_params=[], |
| 111 | + ) |
| 112 | + |
| 113 | + helper_defs.append(func_def) |
| 114 | + |
| 115 | + return helper_defs |
| 116 | + |
| 117 | + def _codegen_helper_function_body( |
| 118 | + self, helper_graph_info: HelperFunctionGraphInfo |
| 119 | + ) -> list[ast.stmt]: |
| 120 | + """Generate the body of a helper function by processing its FX graph.""" |
| 121 | + temp_device_function = self._create_temp_device_function(helper_graph_info) |
| 122 | + param_args = self._create_parameter_args(helper_graph_info) |
| 123 | + |
| 124 | + with temp_device_function: |
| 125 | + results = self._process_helper_graph( |
| 126 | + helper_graph_info, temp_device_function, param_args |
| 127 | + ) |
| 128 | + statements = temp_device_function.body.copy() |
| 129 | + self._ensure_return_statement(statements, results, helper_graph_info.name) |
| 130 | + |
| 131 | + return cast("list[ast.stmt]", statements) |
| 132 | + |
| 133 | + def _create_temp_device_function( |
| 134 | + self, helper_graph_info: HelperFunctionGraphInfo |
| 135 | + ) -> DeviceFunction: |
| 136 | + """Create a temporary DeviceFunction for helper function generation.""" |
| 137 | + # Import here to avoid circular imports |
| 138 | + from .device_function import DeviceFunction |
| 139 | + |
| 140 | + current = DeviceFunction.current() |
| 141 | + |
| 142 | + return DeviceFunction( |
| 143 | + name=f"temp_{helper_graph_info.name}", |
| 144 | + config=current.config, |
| 145 | + codegen=current.codegen, |
| 146 | + ) |
| 147 | + |
| 148 | + def _create_parameter_args( |
| 149 | + self, helper_graph_info: HelperFunctionGraphInfo |
| 150 | + ) -> list[ast.AST]: |
| 151 | + """Create parameter AST nodes for the helper function.""" |
| 152 | + param_names = helper_graph_info._param_names |
| 153 | + return [expr_from_string(param_name) for param_name in param_names] |
| 154 | + |
| 155 | + def _process_helper_graph( |
| 156 | + self, |
| 157 | + helper_graph_info: HelperFunctionGraphInfo, |
| 158 | + temp_device_function: DeviceFunction, |
| 159 | + param_args: list[ast.AST], |
| 160 | + ) -> object: |
| 161 | + """Process the graph using the existing interpreter infrastructure.""" |
| 162 | + from .inductor_lowering import GraphInterpreter |
| 163 | + |
| 164 | + helper_codegen = HelperCodegen(temp_device_function) |
| 165 | + interpreter = GraphInterpreter(helper_graph_info.graph, helper_codegen) |
| 166 | + return interpreter.run(*param_args) |
| 167 | + |
| 168 | + def _ensure_return_statement( |
| 169 | + self, statements: list[ast.AST], results: object, function_name: str |
| 170 | + ) -> None: |
| 171 | + """Ensure the function body has a proper return statement.""" |
| 172 | + if statements and isinstance(statements[-1], ast.Return): |
| 173 | + return |
| 174 | + |
| 175 | + if isinstance(results, ast.AST): |
| 176 | + statements.append(create(ast.Return, value=results)) |
| 177 | + elif isinstance(results, (list, tuple)) and all( |
| 178 | + isinstance(r, ast.AST) for r in results |
| 179 | + ): |
| 180 | + tuple_ast = create(ast.Tuple, elts=list(results), ctx=ast.Load()) |
| 181 | + statements.append(create(ast.Return, value=tuple_ast)) |
| 182 | + else: |
| 183 | + raise RuntimeError( |
| 184 | + f"Helper function {function_name} produced invalid result: {type(results)} {results}" |
| 185 | + ) |
| 186 | + |
| 187 | + |
| 188 | +def codegen_helper_function_graph_info( |
| 189 | + helper_graph_info: HelperFunctionGraphInfo, state: object |
| 190 | +) -> list[object]: |
| 191 | + """Generate code for HelperFunctionGraphInfo objects.""" |
| 192 | + from .inductor_lowering import CodegenState |
| 193 | + from .inductor_lowering import codegen_call_with_graph |
| 194 | + |
| 195 | + if not isinstance(state, CodegenState): |
| 196 | + raise TypeError(f"Expected CodegenState, got {type(state)}") |
| 197 | + |
| 198 | + # For helper functions, we need to inline the function body |
| 199 | + # The helper function takes variable arguments and returns their combination |
| 200 | + |
| 201 | + # Generate temporary variable names for the helper function arguments |
| 202 | + # Use the graph's input nodes to determine the number of parameters |
| 203 | + input_nodes = helper_graph_info.find_input_nodes() |
| 204 | + args: list[ast.AST] = [] |
| 205 | + |
| 206 | + for i in range(len(input_nodes)): |
| 207 | + var_name = state.codegen.tmpvar(prefix=f"helper_arg_{i}") |
| 208 | + args.append(create(ast.Name, id=var_name, ctx=ast.Load())) |
| 209 | + |
| 210 | + # Generate the helper function call |
| 211 | + return codegen_call_with_graph(state.codegen, helper_graph_info.graph, args) |
0 commit comments