|
2 | 2 |
|
3 | 3 | import ast
|
4 | 4 | import builtins
|
5 |
| -from collections.abc import Callable |
6 | 5 | import contextlib
|
7 | 6 | import dataclasses
|
8 | 7 | import functools
|
@@ -339,7 +338,11 @@ def build_rolled_reductions(self) -> None:
|
339 | 338 | for graph_id, graph_info in enumerate([*self.graphs]):
|
340 | 339 | assert graph_id == graph_info.graph_id
|
341 | 340 | roller = ReductionRoller(self, rdim, graph_to_info)
|
342 |
| - new_graph = roller.process(graph_info.graph) |
| 341 | + try: |
| 342 | + new_graph = roller.process(graph_info.graph) |
| 343 | + except NotImplementedError: |
| 344 | + first = False |
| 345 | + break |
343 | 346 | new_graph_id = self.add_graph(
|
344 | 347 | new_graph, type(graph_info), **graph_info.kwargs()
|
345 | 348 | )
|
@@ -807,9 +810,160 @@ def visit_Call(self, node: ast.Call) -> object:
|
807 | 810 | else:
|
808 | 811 | func = self.visit(node.func)
|
809 | 812 |
|
| 813 | + # Special handling for associative_scan |
| 814 | + if isinstance( |
| 815 | + (func_type_info := node.func._type_info), |
| 816 | + CallableType, |
| 817 | + ) and ( |
| 818 | + func_type_info.value is hl.associative_scan or func is hl.associative_scan |
| 819 | + ): |
| 820 | + return self._handle_associative_scan(node, args, kwargs) |
| 821 | + |
810 | 822 | # pyre-ignore[6]
|
811 | 823 | return _CheckForIndexCalls.retry_call(func, args, kwargs)
|
812 | 824 |
|
| 825 | + def _handle_associative_scan( |
| 826 | + self, node: ast.Call, args: list[object], kwargs: dict[str, object] |
| 827 | + ) -> object: |
| 828 | + """Handle associative_scan calls by tracing the combine function as a subgraph.""" |
| 829 | + from ..language import _tracing_ops |
| 830 | + |
| 831 | + combine_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = cast( |
| 832 | + "Callable[[torch.Tensor, torch.Tensor], torch.Tensor]", args[0] |
| 833 | + ) # The combine function |
| 834 | + input_tensor = args[1] # The input tensor |
| 835 | + |
| 836 | + # Extract other arguments from kwargs |
| 837 | + dim = kwargs.get("dim", 0) |
| 838 | + reverse = kwargs.get("reverse", False) |
| 839 | + |
| 840 | + # Detect if we're dealing with tuple inputs |
| 841 | + is_tuple_input = isinstance(input_tensor, (tuple, list)) |
| 842 | + |
| 843 | + # Create a subgraph for the combine function |
| 844 | + if is_tuple_input: |
| 845 | + |
| 846 | + def run_combine_subgraph( |
| 847 | + *args: torch.Tensor, |
| 848 | + ) -> tuple[torch.Tensor, ...]: |
| 849 | + # This will trace the combine function with unpacked tuple inputs |
| 850 | + from .helper_function import extract_helper_function |
| 851 | + |
| 852 | + # For tuple inputs, the combine function expects unpacked arguments |
| 853 | + # args = [left_val1, left_val2, ..., right_val1, right_val2, ...] |
| 854 | + # We need to call: combine_fn(left_val1, left_val2, ..., right_val1, right_val2, ...) |
| 855 | + actual_fn = extract_helper_function(combine_fn) |
| 856 | + result = actual_fn(*args) |
| 857 | + return result if isinstance(result, tuple) else (result,) |
| 858 | + else: |
| 859 | + |
| 860 | + def run_combine_subgraph(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: |
| 861 | + # This will trace the combine function with tensor inputs |
| 862 | + from .helper_function import extract_helper_function |
| 863 | + |
| 864 | + actual_fn = extract_helper_function(combine_fn) |
| 865 | + return actual_fn(a, b) |
| 866 | + |
| 867 | + # Create fake inputs for the combine function |
| 868 | + if is_tuple_input: |
| 869 | + # Handle tuple inputs |
| 870 | + if isinstance(input_tensor, (tuple, list)) and all( |
| 871 | + isinstance(t, torch.Tensor) for t in input_tensor |
| 872 | + ): |
| 873 | + fake_inputs = [] |
| 874 | + for tensor in input_tensor: |
| 875 | + fake_inputs.extend( |
| 876 | + [ |
| 877 | + torch.empty([1], dtype=tensor.dtype, device=tensor.device), |
| 878 | + torch.empty([1], dtype=tensor.dtype, device=tensor.device), |
| 879 | + ] |
| 880 | + ) |
| 881 | + fake_a_and_b = fake_inputs |
| 882 | + else: |
| 883 | + # Fallback for when input_tensor is a proxy tuple |
| 884 | + # Assume 2 tensors of float32 for now |
| 885 | + fake_a_and_b = [torch.empty([1], dtype=torch.float32) for _ in range(4)] |
| 886 | + else: |
| 887 | + # Handle single tensor inputs |
| 888 | + if isinstance(input_tensor, torch.Tensor): |
| 889 | + fake_a = torch.empty( |
| 890 | + [1], dtype=input_tensor.dtype, device=input_tensor.device |
| 891 | + ) |
| 892 | + fake_b = torch.empty( |
| 893 | + [1], dtype=input_tensor.dtype, device=input_tensor.device |
| 894 | + ) |
| 895 | + fake_a_and_b = [fake_a, fake_b] |
| 896 | + else: |
| 897 | + # Fallback for when input_tensor is a proxy |
| 898 | + fake_a = torch.empty([1], dtype=torch.float32) |
| 899 | + fake_b = torch.empty([1], dtype=torch.float32) |
| 900 | + fake_a_and_b = [fake_a, fake_b] |
| 901 | + |
| 902 | + with self.disable_tracing() as tracer: |
| 903 | + combine_graph = proxy_tensor.make_fx( |
| 904 | + run_combine_subgraph, decomposition_table=select_decomp_table() |
| 905 | + )(*fake_a_and_b).graph |
| 906 | + |
| 907 | + combine_graph_id = self.device_ir.add_graph( |
| 908 | + combine_graph, |
| 909 | + HelperFunctionGraphInfo, |
| 910 | + node_args=[], # The combine function doesn't use external args |
| 911 | + ) |
| 912 | + |
| 913 | + # Create the associative_scan tracing operation |
| 914 | + scan_args = ( |
| 915 | + combine_graph_id, |
| 916 | + input_tensor, |
| 917 | + dim, |
| 918 | + reverse, |
| 919 | + is_tuple_input, |
| 920 | + ) |
| 921 | + |
| 922 | + proxy_args, proxy_kwargs = args_to_proxies(tracer, scan_args) |
| 923 | + proxy_out = tracer.create_proxy( |
| 924 | + "call_function", |
| 925 | + _tracing_ops._associative_scan, |
| 926 | + proxy_args, |
| 927 | + proxy_kwargs, |
| 928 | + ) |
| 929 | + |
| 930 | + # The output has the same shape as the input |
| 931 | + if is_tuple_input: |
| 932 | + # For tuple inputs, track each element separately and return a tuple |
| 933 | + proxy_tensor.track_tensor_tree( |
| 934 | + input_tensor, |
| 935 | + proxy_out, |
| 936 | + constant=None, |
| 937 | + tracer=tracer, |
| 938 | + ) |
| 939 | + # Convert the proxy output to a tuple of individual proxies |
| 940 | + tuple_proxies = [] |
| 941 | + assert isinstance( |
| 942 | + input_tensor, (tuple, list) |
| 943 | + ) # Guaranteed when is_tuple_input is True |
| 944 | + for i, tensor in enumerate(input_tensor): |
| 945 | + element_proxy = tracer.create_proxy( |
| 946 | + "call_function", |
| 947 | + operator.getitem, |
| 948 | + (proxy_out, i), |
| 949 | + {}, |
| 950 | + ) |
| 951 | + proxy_tensor.track_tensor_tree( |
| 952 | + tensor, |
| 953 | + element_proxy, |
| 954 | + constant=None, |
| 955 | + tracer=tracer, |
| 956 | + ) |
| 957 | + tuple_proxies.append(tensor) |
| 958 | + return tuple(tuple_proxies) |
| 959 | + proxy_tensor.track_tensor_tree( |
| 960 | + input_tensor, |
| 961 | + proxy_out, |
| 962 | + constant=None, |
| 963 | + tracer=tracer, |
| 964 | + ) |
| 965 | + return proxy_out |
| 966 | + |
813 | 967 | def visit_Attribute(self, node: ast.Attribute) -> object:
|
814 | 968 | return getattr(self.visit(node.value), node.attr)
|
815 | 969 |
|
@@ -898,6 +1052,37 @@ def lower_to_device_ir(func: HostFunction) -> DeviceIR:
|
898 | 1052 | return device_ir
|
899 | 1053 |
|
900 | 1054 |
|
| 1055 | +@dataclasses.dataclass |
| 1056 | +class HelperFunctionGraphInfo(NodeArgsGraphInfo): |
| 1057 | + """Graph info for helper functions in higher-order operations like associative_scan.""" |
| 1058 | + |
| 1059 | + _param_names: list[str] = dataclasses.field(default_factory=list) |
| 1060 | + |
| 1061 | + @property |
| 1062 | + def name(self) -> str: |
| 1063 | + return f"helper_function_{self.graph_id}" |
| 1064 | + |
| 1065 | + def find_input_nodes(self) -> list[torch.fx.Node]: |
| 1066 | + """Find all placeholder nodes (inputs) in the graph.""" |
| 1067 | + return self.graph.find_nodes(op="placeholder") |
| 1068 | + |
| 1069 | + def codegen(self, state: CodegenState) -> list[object]: |
| 1070 | + # For helper functions, we need to inline the function body |
| 1071 | + # The helper function takes variable arguments and returns their combination |
| 1072 | + |
| 1073 | + # Generate temporary variable names for the helper function arguments |
| 1074 | + # Use the graph's input nodes to determine the number of parameters |
| 1075 | + input_nodes = self.find_input_nodes() |
| 1076 | + args: list[ast.AST] = [] |
| 1077 | + |
| 1078 | + for i in range(len(input_nodes)): |
| 1079 | + var_name = state.codegen.tmpvar(prefix=f"helper_arg_{i}") |
| 1080 | + args.append(create(ast.Name, id=var_name, ctx=ast.Load())) |
| 1081 | + |
| 1082 | + # Generate the helper function call |
| 1083 | + return codegen_call_with_graph(state.codegen, self.graph, args) |
| 1084 | + |
| 1085 | + |
901 | 1086 | def remove_unnecessary_tile_index(graph: torch.fx.Graph) -> None:
|
902 | 1087 | """
|
903 | 1088 | Remove unnecessary tile_index nodes from the graph.
|
|
0 commit comments