|
3 | 3 | # SPDX-License-Identifier: MPL-2.0
|
4 | 4 |
|
5 | 5 | from abc import ABC, abstractmethod
|
| 6 | +from contextlib import contextmanager |
| 7 | +from typing import Generator |
6 | 8 |
|
7 | 9 | import numpy as np
|
8 | 10 | from numpy._typing import NDArray
|
@@ -34,6 +36,14 @@ def nr_nodes(self):
|
34 | 36 | def nr_branches(self):
|
35 | 37 | """Returns the number of branches in the graph"""
|
36 | 38 |
|
| 39 | + @property |
| 40 | + def all_branches(self) -> Generator[tuple[int, int], None, None]: |
| 41 | + """Returns all branches in the graph.""" |
| 42 | + return ( |
| 43 | + (self.internal_to_external(source), self.internal_to_external(target)) |
| 44 | + for source, target in self._all_branches() |
| 45 | + ) |
| 46 | + |
37 | 47 | @abstractmethod
|
38 | 48 | def external_to_internal(self, ext_node_id: int) -> int:
|
39 | 49 | """Convert external node id to internal node id (internal)
|
@@ -63,6 +73,14 @@ def has_node(self, node_id: int) -> bool:
|
63 | 73 |
|
64 | 74 | return self._has_node(node_id=internal_node_id)
|
65 | 75 |
|
| 76 | + def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]: |
| 77 | + """Return all branches that have the node as an endpoint.""" |
| 78 | + int_node_id = self.external_to_internal(node_id) |
| 79 | + internal_edges = self._in_branches(int_node_id=int_node_id) |
| 80 | + return ( |
| 81 | + (self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges |
| 82 | + ) |
| 83 | + |
66 | 84 | def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
|
67 | 85 | """Add a node to the graph."""
|
68 | 86 | if self.has_node(ext_node_id):
|
@@ -158,12 +176,34 @@ def delete_branch_array(self, branch_array: BranchArray, raise_on_fail: bool = T
|
158 | 176 | if self._branch_is_relevant(branch):
|
159 | 177 | self.delete_branch(branch.from_node.item(), branch.to_node.item(), raise_on_fail=raise_on_fail)
|
160 | 178 |
|
161 |
| - def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool = True) -> None: |
| 179 | + def delete_branch3_array(self, branch3_array: Branch3Array, raise_on_fail: bool = True) -> None: |
162 | 180 | """Delete all branch3s in the branch3 array from the graph."""
|
163 |
| - for branch3 in branch_array: |
| 181 | + for branch3 in branch3_array: |
164 | 182 | branches = _get_branch3_branches(branch3)
|
165 | 183 | self.delete_branch_array(branches, raise_on_fail=raise_on_fail)
|
166 | 184 |
|
| 185 | + @contextmanager |
| 186 | + def tmp_remove_nodes(self, nodes: list[int]) -> Generator: |
| 187 | + """Context manager that temporarily removes nodes and their branches from the graph. |
| 188 | + Example: |
| 189 | + >>> with graph.tmp_remove_nodes([1, 2, 3]): |
| 190 | + >>> assert not graph.has_node(1) |
| 191 | + >>> assert graph.has_node(1) |
| 192 | + In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without |
| 193 | + considering certain nodes. |
| 194 | + """ |
| 195 | + edge_list = [] |
| 196 | + for node in nodes: |
| 197 | + edge_list += list(self.in_branches(node)) |
| 198 | + self.delete_node(node) |
| 199 | + |
| 200 | + yield |
| 201 | + |
| 202 | + for node in nodes: |
| 203 | + self.add_node(node) |
| 204 | + for source, target in edge_list: |
| 205 | + self.add_branch(source, target) |
| 206 | + |
167 | 207 | def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
|
168 | 208 | """Calculate the shortest path between two nodes
|
169 | 209 |
|
@@ -235,8 +275,49 @@ def get_connected(
|
235 | 275 | nodes_to_ignore=self._externals_to_internals(nodes_to_ignore),
|
236 | 276 | inclusive=inclusive,
|
237 | 277 | )
|
| 278 | + |
238 | 279 | return self._internals_to_externals(nodes)
|
239 | 280 |
|
| 281 | + def find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: |
| 282 | + """Find the first connected node to the node_id from the candidate_node_ids |
| 283 | +
|
| 284 | + Note: |
| 285 | + If multiple candidate nodes are connected to the node, the first one found is returned. |
| 286 | + There is no guarantee that the same candidate node will be returned each time. |
| 287 | +
|
| 288 | + Raises: |
| 289 | + MissingNodeError: if no connected node is found |
| 290 | + ValueError: if the node_id is in candidate_node_ids |
| 291 | + """ |
| 292 | + internal_node_id = self.external_to_internal(node_id) |
| 293 | + internal_candidates = self._externals_to_internals(candidate_node_ids) |
| 294 | + if internal_node_id in internal_candidates: |
| 295 | + raise ValueError("node_id cannot be in candidate_node_ids") |
| 296 | + return self.internal_to_external(self._find_first_connected(internal_node_id, internal_candidates)) |
| 297 | + |
| 298 | + def get_downstream_nodes(self, node_id: int, start_node_ids: list[int], inclusive: bool = False) -> list[int]: |
| 299 | + """Find all nodes downstream of the node_id with respect to the start_node_ids |
| 300 | +
|
| 301 | + Example: |
| 302 | + given this graph: [1] - [2] - [3] - [4] |
| 303 | + >>> graph.get_downstream_nodes(2, [1]) == [3, 4] |
| 304 | + >>> graph.get_downstream_nodes(2, [1], inclusive=True) == [2, 3, 4] |
| 305 | +
|
| 306 | + args: |
| 307 | + node_id: node id to start the search from |
| 308 | + start_node_ids: list of node ids considered 'above' the node_id |
| 309 | + inclusive: whether to include the given node id in the result |
| 310 | + returns: |
| 311 | + list of node ids sorted by distance, downstream of to the node id |
| 312 | + """ |
| 313 | + connected_node = self.find_first_connected(node_id, start_node_ids) |
| 314 | + path, _ = self.get_shortest_path(node_id, connected_node) |
| 315 | + _, upstream_node, *_ = ( |
| 316 | + path # path is at least 2 elements long or find_first_connected would have raised an error |
| 317 | + ) |
| 318 | + |
| 319 | + return self.get_connected(node_id, [upstream_node], inclusive) |
| 320 | + |
240 | 321 | def find_fundamental_cycles(self) -> list[list[int]]:
|
241 | 322 | """Find all fundamental cycles in the graph.
|
242 | 323 | Returns:
|
@@ -270,9 +351,15 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
|
270 | 351 | return branch.is_active.item()
|
271 | 352 | return True
|
272 | 353 |
|
| 354 | + @abstractmethod |
| 355 | + def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ... |
| 356 | + |
273 | 357 | @abstractmethod
|
274 | 358 | def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
|
275 | 359 |
|
| 360 | + @abstractmethod |
| 361 | + def _find_first_connected(self, node_id: int, candidate_node_ids: list[int]) -> int: ... |
| 362 | + |
276 | 363 | @abstractmethod
|
277 | 364 | def _has_branch(self, from_node_id, to_node_id) -> bool: ...
|
278 | 365 |
|
@@ -307,6 +394,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
|
307 | 394 | @abstractmethod
|
308 | 395 | def _find_fundamental_cycles(self) -> list[list[int]]: ...
|
309 | 396 |
|
| 397 | + @abstractmethod |
| 398 | + def _all_branches(self) -> Generator[tuple[int, int], None, None]: ... |
| 399 | + |
310 | 400 |
|
311 | 401 | def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
|
312 | 402 | node_1 = branch3.node_1.item()
|
|
0 commit comments