Skip to content

Feature: add tmp_remove_nodes method to graph #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/power_grid_model_ds/_core/model/graphs/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MPL-2.0

from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Generator

import numpy as np
Expand Down Expand Up @@ -72,6 +73,14 @@ def has_node(self, node_id: int) -> bool:

return self._has_node(node_id=internal_node_id)

def in_branches(self, node_id: int) -> Generator[tuple[int, int], None, None]:
"""Return all branches that have the node as an endpoint."""
int_node_id = self.external_to_internal(node_id)
internal_edges = self._in_branches(int_node_id=int_node_id)
return (
(self.internal_to_external(source), self.internal_to_external(target)) for source, target in internal_edges
)

def add_node(self, ext_node_id: int, raise_on_fail: bool = True) -> None:
"""Add a node to the graph."""
if self.has_node(ext_node_id):
Expand Down Expand Up @@ -173,6 +182,28 @@ def delete_branch3_array(self, branch3_array: Branch3Array, raise_on_fail: bool
branches = _get_branch3_branches(branch3)
self.delete_branch_array(branches, raise_on_fail=raise_on_fail)

@contextmanager
def tmp_remove_nodes(self, nodes: list[int]) -> Generator:
"""Context manager that temporarily removes nodes and their branches from the graph.
Example:
>>> with graph.tmp_remove_nodes([1, 2, 3]):
>>> assert not graph.has_node(1)
>>> assert graph.has_node(1)
In practice, this is useful when you want to e.g. calculate the shortest path between two nodes without
considering certain nodes.
"""
edge_list = []
for node in nodes:
edge_list += list(self.in_branches(node))
self.delete_node(node)

yield

for node in nodes:
self.add_node(node)
for source, target in edge_list:
self.add_branch(source, target)

def get_shortest_path(self, ext_start_node_id: int, ext_end_node_id: int) -> tuple[list[int], int]:
"""Calculate the shortest path between two nodes

Expand Down Expand Up @@ -279,6 +310,9 @@ def _branch_is_relevant(self, branch: BranchArray) -> bool:
return branch.is_active.item()
return True

@abstractmethod
def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]: ...

@abstractmethod
def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bool = False) -> list[int]: ...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def _get_connected(self, node_id: int, nodes_to_ignore: list[int], inclusive: bo

return connected_nodes

def _in_branches(self, int_node_id: int) -> Generator[tuple[int, int], None, None]:
return ((source, target) for source, target, _ in self._graph.in_edges(int_node_id))

def _find_fundamental_cycles(self) -> list[list[int]]:
"""Find all fundamental cycles in the graph using Rustworkx.

Expand Down
40 changes: 40 additions & 0 deletions tests/unit/model/graphs/test_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

"""Grid tests"""

from collections import Counter

import numpy as np
import pytest
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -55,6 +57,17 @@ def test_graph_all_branches_parallel(graph):
assert [(1, 2), (1, 2), (2, 1)] == list(graph.all_branches)


def test_graph_in_branches(graph):
graph.add_node(1)
graph.add_node(2)
graph.add_branch(1, 2)
graph.add_branch(1, 2)
graph.add_branch(2, 1)

assert [(2, 1), (2, 1), (2, 1)] == list(graph.in_branches(1))
assert [(1, 2), (1, 2), (1, 2)] == list(graph.in_branches(2))


def test_graph_delete_branch(graph):
"""Test whether a branch is deleted correctly"""
graph.add_node(1)
Expand Down Expand Up @@ -338,3 +351,30 @@ def test_get_connected_ignore_multiple_nodes(self, graph_with_2_routes):
connected_nodes = graph.get_connected(node_id=1, nodes_to_ignore=[2, 4])

assert {5} == set(connected_nodes)


def test_tmp_remove_nodes(graph_with_2_routes) -> None:
graph = graph_with_2_routes

assert graph.nr_branches == 4

# add parallel branches to test whether they are restored correctly
graph.add_branch(1, 5)
graph.add_branch(5, 1)

assert graph.nr_nodes == 5
assert graph.nr_branches == 6

before_sets = [frozenset(branch) for branch in graph.all_branches]
counter_before = Counter(before_sets)

with graph.tmp_remove_nodes([1, 2]):
assert graph.nr_nodes == 3
assert list(graph.all_branches) == [(5, 4)]

assert graph.nr_nodes == 5
assert graph.nr_branches == 6

after_sets = [frozenset(branch) for branch in graph.all_branches]
counter_after = Counter(after_sets)
assert counter_before == counter_after