Skip to content

Commit 99e7c9f

Browse files
Fixed #3 to allow nodes with multiple shared edges to delete each other
1 parent aa4031d commit 99e7c9f

File tree

3 files changed

+58
-13
lines changed

3 files changed

+58
-13
lines changed

django_postgresql_dag/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def add_child(self, child, **kwargs):
8686
def remove_child(self, child, delete_node=False):
8787
"""Removes the edge connecting this node to child, and optionally deletes the child node as well"""
8888
if child in self.children.all():
89-
self.children.through.objects.get(parent=self, child=child).delete()
89+
self.children.through.objects.filter(parent=self, child=child).delete()
9090
if delete_node:
9191
# Note: Per django docs:
9292
# https://docs.djangoproject.com/en/dev/ref/models/instances/#deleting-objects
@@ -100,7 +100,7 @@ def add_parent(self, parent, *args, **kwargs):
100100
def remove_parent(self, parent, delete_node=False):
101101
"""Removes the edge connecting this node to parent, and optionally deletes the parent node as well"""
102102
if parent in self.parents.all():
103-
parent.children.through.objects.get(parent=parent, child=self).delete()
103+
parent.children.through.objects.filter(parent=parent, child=self).delete()
104104
if delete_node:
105105
# Note: Per django docs:
106106
# https://docs.djangoproject.com/en/dev/ref/models/instances/#deleting-objects

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from setuptools import setup
55

6-
version = '0.1.7'
6+
version = '0.1.8'
77

88
classifiers = [
99
"Development Status :: 3 - Alpha",

tests/test.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_02_dag(self):
193193
)
194194

195195
log.debug("get_leaves")
196-
self.assertEqual([p.name for p in root.leaves()], ["b2", "c1", "c2", "b1"])
196+
self.assertEqual(set([p.name for p in root.leaves()]), set(["b2", "c1", "c2", "b1"]))
197197
log.debug("get_roots")
198198
self.assertEqual([p.name for p in c2.roots()], ["root"])
199199

@@ -493,7 +493,51 @@ def test_02_dag(self):
493493
log.debug(f"Node count: {NetworkNode.objects.count()}")
494494
log.debug(f"Edge count: {NetworkEdge.objects.count()}")
495495

496-
def test_03_deep_dag(self):
496+
def test_03_multilinked_nodes(self):
497+
log = logging.getLogger("test_03")
498+
log.debug("Test deletion of nodes two nodes with multiple shared edges")
499+
500+
shared_edge_count = 5
501+
502+
def create_multilinked_nodes(shared_edge_count):
503+
log.debug("Creating multiple links between a parent and child node")
504+
child_node = NetworkNode.objects.create()
505+
parent_node = NetworkNode.objects.create()
506+
507+
# Call this multiple times to create multiple edges between same parent/child
508+
for _ in range(shared_edge_count):
509+
child_node.add_parent(parent_node)
510+
511+
return child_node, parent_node
512+
513+
def delete_parents():
514+
child_node, parent_node = create_multilinked_nodes(shared_edge_count)
515+
516+
# Refresh the related manager
517+
child_node.refresh_from_db()
518+
519+
self.assertEqual(child_node.parents.count(), shared_edge_count)
520+
log.debug(f"Initial parents count: {child_node.parents.count()}")
521+
child_node.remove_parent(parent_node)
522+
self.assertEqual(child_node.parents.count(), 0)
523+
log.debug(f"Final parents count: {child_node.parents.count()}")
524+
525+
def delete_children():
526+
child_node, parent_node = create_multilinked_nodes(shared_edge_count)
527+
528+
# Refresh the related manager
529+
parent_node.refresh_from_db()
530+
531+
self.assertEqual(parent_node.children.count(), shared_edge_count)
532+
log.debug(f"Initial children count: {parent_node.children.count()}")
533+
parent_node.remove_child(child_node)
534+
self.assertEqual(parent_node.children.count(), 0)
535+
log.debug(f"Final children count: {parent_node.children.count()}")
536+
537+
delete_parents()
538+
delete_children()
539+
540+
def test_04_deep_dag(self):
497541
"""
498542
Create a deep graph and check that graph operations run in a
499543
reasonable amount of time (linear in size of graph, not
@@ -503,11 +547,10 @@ def test_03_deep_dag(self):
503547
def run_test():
504548
# Using the graph generation algorithm below, the number of potential
505549
# paths from node 0 doubles for each increase in n.
506-
# number_of_paths = 2^(n-1) WRONG!!!
507-
# When n=22, there are on the order of 1 million paths through the graph
508-
# from node 0, so results for intermediate nodes need to be cached
550+
# When n=22, there are many paths through the graph from node 0,
551+
# so results for intermediate nodes need to be cached
509552

510-
log = logging.getLogger("test_03")
553+
log = logging.getLogger("test_04")
511554

512555
n = 22 # Keep it an even number
513556

@@ -547,8 +590,9 @@ def run_test():
547590
first = NetworkNode.objects.get(name="0")
548591
last = NetworkNode.objects.get(name=str(2 * n - 1))
549592

550-
log.debug(f"Path exists: {first.path_exists(last, max_depth=n)}")
551-
self.assertTrue(first.path_exists(last, max_depth=n), True)
593+
path_exists = first.path_exists(last, max_depth=n)
594+
log.debug(f"Path exists: {path_exists}")
595+
self.assertTrue(path_exists, True)
552596
self.assertEqual(first.distance(last, max_depth=n), n - 1)
553597

554598
log.debug(f"Node count: {NetworkNode.objects.count()}")
@@ -560,8 +604,9 @@ def run_test():
560604
)
561605

562606
middle = NetworkNode.objects.get(pk=n - 1)
563-
log.debug("Distance")
564-
self.assertEqual(first.distance(middle, max_depth=n), n / 2 - 1)
607+
distance = first.distance(middle, max_depth=n)
608+
log.debug(f"Distance: {distance}")
609+
self.assertEqual(distance, n / 2 - 1)
565610

566611
# Run the test, raising an error if the code times out
567612
p = multiprocessing.Process(target=run_test)

0 commit comments

Comments
 (0)