Skip to content

Commit 1154f96

Browse files
committed
Add tmp-remove-nodes method
Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com>
1 parent 4a36057 commit 1154f96

File tree

4 files changed

+48
-1
lines changed

4 files changed

+48
-1
lines changed

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.0
1+
1.1

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
from abc import ABC, abstractmethod
6+
from contextlib import contextmanager
7+
from typing import Generator
68

79
import numpy as np
810
from numpy._typing import NDArray
@@ -164,6 +166,31 @@ def delete_branch3_array(self, branch_array: Branch3Array, raise_on_fail: bool =
164166
branches = _get_branch3_branches(branch3)
165167
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)
166168

169+
@contextmanager
170+
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
171+
"""Context manager that temporarily removes nodes and their branches from the graph.
172+
Example:
173+
>>> with graph.tmp_remove_nodes([1, 2, 3]):
174+
>>> assert not graph.has_node(1)
175+
>>> assert graph.has_node(1)
176+
In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without
177+
considering certain nodes.
178+
"""
179+
edge_list = []
180+
for node in nodes:
181+
internal_node = self.external_to_internal(node)
182+
node_edges = [
183+
(self.internal_to_external(source), self.internal_to_external(target))
184+
for source, target in self._in_edges(internal_node)
185+
]
186+
edge_list += node_edges
187+
self._delete_node(internal_node)
188+
yield edge_list
189+
for node in nodes:
190+
self.add_node(node)
191+
for source, target in edge_list:
192+
self.add_branch(source, target)
193+
167194
def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
168195
"""Calculate the shortest path between two nodes
169196
@@ -270,6 +297,13 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
270297
return branch.is_active.item()
271298
return True
272299

300+
@abstractmethod
301+
def _in_edges(self, internal_node: int) -> list[tuple[int, int]]:
302+
"""Return all edges a node occurs in.
303+
304+
Return a list of tuples with the source and target node id. These are internal node ids.
305+
"""
306+
273307
@abstractmethod
274308
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...
275309

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo
9999

100100
return connected_nodes
101101

102+
def _in_edges(self, internal_node: int) -> list[tuple[int, int]]:
103+
return [(source, target) for source, target, _ in self._graph.in_edges(internal_node)]
104+
102105
def _find_fundamental_cycles(self) -> list[list[int]]:
103106
"""Find all fundamental cycles in the graph using Rustworkx.
104107

tests/unit/model/graphs/test_graph_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ def test_graph_initialize(graph):
2727
assert 1 == graph.nr_branches
2828

2929

30+
def test_graph_has_branch(graph):
31+
graph.add_node(1)
32+
graph.add_node(2)
33+
graph.add_branch(1, 2)
34+
35+
assert graph.has_branch(1, 2)
36+
assert graph.has_branch(2, 1) # reversed should work too
37+
assert not graph.has_branch(1, 3)
38+
39+
3040
def test_graph_delete_branch(graph):
3141
"""Test whether a branch is deleted correctly"""
3242
graph.add_node(1)

0 commit comments

Comments
 (0)