|
8 | 8 | from datetime import datetime
|
9 | 9 | from enum import IntEnum
|
10 | 10 | from pathlib import Path
|
11 |
| -from typing import Any, Dict, Iterator, List, Optional, TextIO, Tuple, Type, TypeVar, Union |
| 11 | +from typing import Any, Dict, Iterator, List, Optional, Set, TextIO, Tuple, Type, TypeVar, Union |
12 | 12 |
|
13 | 13 | import networkx # type:ignore
|
14 | 14 |
|
@@ -1042,24 +1042,75 @@ def _add_missing_input_defaults(component_inputs: Dict[str, Any], component_inpu
|
1042 | 1042 |
|
1043 | 1043 | return component_inputs
|
1044 | 1044 |
|
| 1045 | + def _tiebreak_waiting_components( |
| 1046 | + self, |
| 1047 | + component_name: str, |
| 1048 | + priority: ComponentPriority, |
| 1049 | + priority_queue: FIFOPriorityQueue, |
| 1050 | + topological_sort: Union[Dict[str, int], None], |
| 1051 | + ): |
| 1052 | + """ |
| 1053 | + Decides which component to run when multiple components are waiting for inputs with the same priority. |
| 1054 | +
|
| 1055 | + :param component_name: The name of the component. |
| 1056 | + :param priority: Priority of the component. |
| 1057 | + :param priority_queue: Priority queue of component names. |
| 1058 | + :param topological_sort: Cached topological sort of all components in the pipeline. |
| 1059 | + """ |
| 1060 | + components_with_same_priority = [component_name] |
| 1061 | + |
| 1062 | + while len(priority_queue) > 0: |
| 1063 | + next_priority, next_component_name = priority_queue.peek() |
| 1064 | + if next_priority == priority: |
| 1065 | + priority_queue.pop() # actually remove the component |
| 1066 | + components_with_same_priority.append(next_component_name) |
| 1067 | + else: |
| 1068 | + break |
| 1069 | + |
| 1070 | + if len(components_with_same_priority) > 1: |
| 1071 | + if topological_sort is None: |
| 1072 | + if networkx.is_directed_acyclic_graph(self.graph): |
| 1073 | + topological_sort = networkx.lexicographical_topological_sort(self.graph) |
| 1074 | + topological_sort = {node: idx for idx, node in enumerate(topological_sort)} |
| 1075 | + else: |
| 1076 | + condensed = networkx.condensation(self.graph) |
| 1077 | + condensed_sorted = {node: idx for idx, node in enumerate(networkx.topological_sort(condensed))} |
| 1078 | + topological_sort = { |
| 1079 | + component_name: condensed_sorted[node] |
| 1080 | + for component_name, node in condensed.graph["mapping"].items() |
| 1081 | + } |
| 1082 | + |
| 1083 | + components_with_same_priority = sorted( |
| 1084 | + components_with_same_priority, key=lambda comp_name: (topological_sort[comp_name], comp_name.lower()) |
| 1085 | + ) |
| 1086 | + |
| 1087 | + component_name = components_with_same_priority[0] |
| 1088 | + |
| 1089 | + return component_name, topological_sort |
| 1090 | + |
1045 | 1091 | @staticmethod
|
1046 | 1092 | def _write_component_outputs(
|
1047 |
| - component_name, component_outputs, inputs, receivers, include_outputs_from |
| 1093 | + component_name: str, |
| 1094 | + component_outputs: Dict[str, Any], |
| 1095 | + inputs: Dict[str, Any], |
| 1096 | + receivers: List[Tuple], |
| 1097 | + include_outputs_from: Set[str], |
1048 | 1098 | ) -> Dict[str, Any]:
|
1049 | 1099 | """
|
1050 | 1100 | Distributes the outputs of a component to the input sockets that it is connected to.
|
1051 | 1101 |
|
1052 | 1102 | :param component_name: The name of the component.
|
1053 | 1103 | :param component_outputs: The outputs of the component.
|
1054 | 1104 | :param inputs: The current global input state.
|
1055 |
| - :param receivers: List of receiver_name, sender_socket, receiver_socket for connected components. |
| 1105 | + :param receivers: List of components that receive inputs from the component. |
1056 | 1106 | :param include_outputs_from: List of component names that should always return an output from the pipeline.
|
1057 | 1107 | """
|
1058 | 1108 | for receiver_name, sender_socket, receiver_socket in receivers:
|
1059 | 1109 | # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
|
1060 | 1110 | # that the sender did not produce an output for this socket.
|
1061 | 1111 | # This allows us to track if a pre-decessor already ran but did not produce an output.
|
1062 | 1112 | value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
|
| 1113 | + |
1063 | 1114 | if receiver_name not in inputs:
|
1064 | 1115 | inputs[receiver_name] = {}
|
1065 | 1116 |
|
|
0 commit comments