Skip to content

Commit d9521f6

Browse files
committed
merge all_branches
Signed-off-by: Thijs Baaijen <13253091+Thijss@users.noreply.github.com>
2 parents 45780e7 + 3daad3f commit d9521f6

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

src/power_grid_model_ds/_core/model/graphs/models/base.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
from abc import ABC, abstractmethod
6-
from contextlib import contextmanager
76
from typing import Generator
87

98
import numpy as np
@@ -37,11 +36,12 @@ def nr_branches(self):
3736
"""Returns the number of branches in the graph"""
3837

3938
@property
40-
@abstractmethod
41-
def all_branches(self) -> list[frozenset[int]]:
42-
"""Returns all branches in the graph as a list of node pairs (frozensets).
43-
Warning: Depending on graph engine, performance could be slow for large graphs
44-
"""
39+
def all_branches(self) -> Generator[tuple[int, int], None, None]:
40+
"""Returns all branches in the graph."""
41+
return (
42+
(self.internal_to_external(source), self.internal_to_external(target))
43+
for source, target in self._all_branches()
44+
)
4545

4646
@abstractmethod
4747
def external_to_internal(self, ext_node_id: int) -> int:
@@ -348,6 +348,9 @@ def _get_components(self, substation_nodes: list[int]) -> list[list[int]]: ...
348348
@abstractmethod
349349
def _find_fundamental_cycles(self) -> list[list[int]]: ...
350350

351+
@abstractmethod
352+
def _all_branches(self) -> Generator[tuple[int, int], None, None]: ...
353+
351354

352355
def _get_branch3_branches(branch3: Branch3Array) -> BranchArray:
353356
node_1 = branch3.node_1.item()

src/power_grid_model_ds/_core/model/graphs/models/rustworkx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MPL-2.0
44

55
import logging
6+
from typing import Generator
67

78
import rustworkx as rx
89
from rustworkx import NoEdgeBetweenNodes
@@ -119,6 +120,9 @@ def _find_fundamental_cycles(self) -> list[list[int]]:
119120
"""
120121
return find_fundamental_cycles_rustworkx(self._graph)
121122

123+
def _all_branches(self) -> Generator[tuple[int, int], None, None]:
124+
return ((source, target) for source, target in self._graph.edge_list())
125+
122126

123127
class _NodeVisitor(BFSVisitor):
124128
def __init__(self, nodes_to_ignore: list[int]):

tests/unit/model/graphs/test_graph_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_graph_all_branches(graph):
4444
graph.add_node(2)
4545
graph.add_branch(1, 2)
4646

47-
assert [{1, 2}] == graph.all_branches
47+
assert [(1, 2)] == list(graph.all_branches)
4848

4949

5050
def test_graph_all_branches_parallel(graph):
@@ -54,7 +54,7 @@ def test_graph_all_branches_parallel(graph):
5454
graph.add_branch(1, 2)
5555
graph.add_branch(2, 1)
5656

57-
assert [{1, 2}, {1, 2}, {1, 2}] == graph.all_branches
57+
assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)
5858

5959

6060
def test_graph_delete_branch(graph):

0 commit comments

Comments
 (0)