Skip to content

Commit 089e76a

Browse files
Added ability to limit the area of the graph the CTE will search
1 parent 3b2fab6 commit 089e76a

File tree

1 file changed

+117
-36
lines changed

1 file changed

+117
-36
lines changed

django_postgresql_dag/models.py

Lines changed: 117 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212
https://github.com/stdbrouw/django-treebeard-dag
1313
"""
1414

15+
from django.apps import apps
1516
from django.db import models, connection
1617
from django.db.models import Case, When
1718
from django.core.exceptions import ValidationError
1819

20+
LIMITING_CLAUSE_1 = """AND second.{fk_field_name}_id = {fk_id}"""
21+
LIMITING_CLAUSE_2 = """WHERE {relationship_table}.{fk_field_name}_id = {fk_id}"""
1922

2023
ANCESTOR_QUERY = """
2124
WITH RECURSIVE traverse(id, depth) AS (
@@ -24,11 +27,13 @@
2427
LEFT OUTER JOIN {relationship_table} AS second
2528
ON first.parent_id = second.child_id
2629
WHERE first.child_id = %(id)s
30+
{limiting_clause_1}
2731
UNION
2832
SELECT DISTINCT parent_id, traverse.depth + 1
2933
FROM traverse
3034
INNER JOIN {relationship_table}
3135
ON {relationship_table}.child_id = traverse.id
36+
{limiting_clause_2}
3237
)
3338
SELECT id FROM traverse
3439
GROUP BY id
@@ -42,11 +47,13 @@
4247
LEFT OUTER JOIN {relationship_table} AS second
4348
ON first.child_id = second.parent_id
4449
WHERE first.parent_id = %(id)s
50+
{limiting_clause_1}
4551
UNION
4652
SELECT DISTINCT child_id, traverse.depth + 1
4753
FROM traverse
4854
INNER JOIN {relationship_table}
4955
ON {relationship_table}.parent_id = traverse.id
56+
{limiting_clause_2}
5057
)
5158
SELECT id FROM traverse
5259
GROUP BY id
@@ -120,7 +127,7 @@ def _filter_order(queryset, field_names, values):
120127
For instance
121128
_filter_order(self.__class__.objects, "pk", ids)
122129
returns a queryset of the current class, with instances where the 'pk' field matches an id in ids
123-
130+
124131
"""
125132
if not isinstance(field_names, list):
126133
field_names = [field_names]
@@ -136,6 +143,17 @@ def _filter_order(queryset, field_names, values):
136143
def node_factory(edge_model, children_null=True, base_model=models.Model):
137144
edge_model_table = edge_model._meta.db_table
138145

146+
def get_foreign_key_field(instance=None):
147+
"""
148+
Provided a model instance, checks if the edge model has a ForeignKey field to the model for that instance, and then returns the field name and instance id
149+
"""
150+
if instance is not None:
151+
for field in edge_model._meta.get_fields():
152+
if field.related_model is instance._meta.model:
153+
# Return the first field that matches
154+
return (field.name, instance.id)
155+
return (None, None)
156+
139157
class Node(base_model):
140158
children = models.ManyToManyField(
141159
"self",
@@ -169,68 +187,118 @@ def filter_order_ids(self, ids):
169187
"""
170188
return _filter_order(self.__class__.objects, "pk", ids)
171189

172-
def ancestors_ids(self):
190+
def ancestors_ids(self, limiting_instance=None):
191+
fk_field_name, fk_value = get_foreign_key_field(limiting_instance)
192+
if fk_field_name is not None and fk_value is not None:
193+
limiting_clause_1 = LIMITING_CLAUSE_1.format(
194+
fk_field_name=fk_field_name, fk_value=fk_value
195+
)
196+
limiting_clause_2 = LIMITING_CLAUSE_2.format(
197+
relationship_table=edge_model_table,
198+
fk_field_name=fk_field_name,
199+
fk_value=fk_value,
200+
)
201+
else:
202+
limiting_clause_1, limiting_clause_2 = ("", "")
203+
173204
with connection.cursor() as cursor:
174205
cursor.execute(
175-
ANCESTOR_QUERY.format(relationship_table=edge_model_table),
206+
ANCESTOR_QUERY.format(
207+
relationship_table=edge_model_table,
208+
limiting_clause_1=limiting_clause_1,
209+
limiting_clause_2=limiting_clause_2,
210+
),
176211
{"id": self.id},
177212
)
178213
return [row[0] for row in cursor.fetchall()]
179214

180-
def ancestors_and_self_ids(self):
181-
return self.ancestors_ids() + [self.id]
215+
def ancestors_and_self_ids(self, limiting_instance=None):
216+
return self.ancestors_ids(limiting_instance=limiting_instance) + [self.id]
217+
218+
def self_and_ancestors_ids(self, limiting_instance=None):
219+
return self.ancestors_and_self_ids(limiting_instance=limiting_instance)[
220+
::-1
221+
]
182222

183-
def self_and_ancestors_ids(self):
184-
return self.ancestors_and_self_ids()[::-1]
223+
def ancestors(self, limiting_instance=None):
224+
return self.filter_order_ids(
225+
self.ancestors_ids(limiting_instance=limiting_instance)
226+
)
185227

186-
def ancestors(self):
187-
return self.filter_order_ids(self.ancestors_ids())
228+
def ancestors_and_self(self, limiting_instance=None):
229+
return self.filter_order_ids(
230+
self.ancestors_and_self_ids(limiting_instance=limiting_instance)
231+
)
188232

189-
def ancestors_and_self(self):
190-
return self.filter_order_ids(self.ancestors_and_self_ids())
233+
def self_and_ancestors(self, limiting_instance=None):
234+
return self.ancestors_and_self(limiting_instance=limiting_instance)[::-1]
191235

192-
def self_and_ancestors(self):
193-
return self.ancestors_and_self()[::-1]
236+
def descendants_ids(self, limiting_instance=None):
237+
fk_field_name, fk_value = get_foreign_key_field(limiting_instance)
238+
if fk_field_name is not None and fk_value is not None:
239+
limiting_clause_1 = LIMITING_CLAUSE_1.format(
240+
fk_field_name=fk_field_name, fk_value=fk_value
241+
)
242+
limiting_clause_2 = LIMITING_CLAUSE_2.format(
243+
relationship_table=edge_model_table,
244+
fk_field_name=fk_field_name,
245+
fk_value=fk_value,
246+
)
247+
else:
248+
limiting_clause_1, limiting_clause_2 = ("", "")
194249

195-
def descendants_ids(self):
196250
with connection.cursor() as cursor:
197251
cursor.execute(
198-
DESCENDANT_QUERY.format(relationship_table=edge_model_table),
252+
DESCENDANT_QUERY.format(
253+
relationship_table=edge_model_table,
254+
limiting_clause_1=limiting_clause_1,
255+
limiting_clause_2=limiting_clause_2,
256+
),
199257
{"id": self.id},
200258
)
201259
return [row[0] for row in cursor.fetchall()]
202260

203-
def self_and_descendants_ids(self):
204-
return [self.id] + self.descendants_ids()
261+
def self_and_descendants_ids(self, limiting_instance=None):
262+
return [self.id] + self.descendants_ids(limiting_instance=limiting_instance)
205263

206-
def descendants_and_self_ids(self):
207-
return self.self_and_descendants_ids()[::-1]
264+
def descendants_and_self_ids(self, limiting_instance=None):
265+
return self.self_and_descendants_ids(limiting_instance=limiting_instance)[
266+
::-1
267+
]
208268

209-
def descendants(self):
210-
return self.filter_order_ids(self.descendants_ids())
269+
def descendants(self, limiting_instance=None):
270+
return self.filter_order_ids(
271+
self.descendants_ids(limiting_instance=limiting_instance)
272+
)
211273

212-
def self_and_descendants(self):
213-
return self.filter_order_ids(self.self_and_descendants_ids())
274+
def self_and_descendants(self, limiting_instance=None):
275+
return self.filter_order_ids(
276+
self.self_and_descendants_ids(limiting_instance=limiting_instance)
277+
)
214278

215-
def descendants_and_self(self):
216-
return self.self_and_descendants()[::-1]
279+
def descendants_and_self(self, limiting_instance=None):
280+
return self.self_and_descendants(limiting_instance=limiting_instance)[::-1]
217281

218-
def clan_ids(self):
282+
def clan_ids(self, limiting_instance=None):
219283
"""
220284
Returns a list of ids with all ancestors, self, and all descendants
221285
"""
222-
return self.ancestors_ids() + self.self_and_descendants_ids()
286+
return self.ancestors_ids(
287+
limiting_instance=limiting_instance
288+
) + self.self_and_descendants_ids(limiting_instance=limiting_instance)
223289

224-
def clan(self):
290+
def clan(self, limiting_instance=None):
225291
"""
226292
Returns a queryset with all ancestors, self, and all descendants
227293
"""
228-
return self.filter_order_ids(self.clan_ids())
294+
return self.filter_order_ids(
295+
self.clan_ids(limiting_instance=limiting_instance)
296+
)
229297

230298
def descendants_edges_ids(self, cached_results=None):
231299
"""
232300
Returns a set of descendants edges
233-
# ToDo: Modify to use CTE
301+
# ToDo: Modify to use CTE and sort topologically
234302
"""
235303
if cached_results is None:
236304
cached_results = dict()
@@ -240,7 +308,9 @@ def descendants_edges_ids(self, cached_results=None):
240308
edge_set = set()
241309
for f in self.children.all():
242310
edge_set.add(edge_model.objects.get(parent=self.id, child=f.id).id)
243-
edge_set.update(f.descendants_edges_ids(cached_results=cached_results))
311+
edge_set.update(
312+
f.descendants_edges_ids(cached_results=cached_results)
313+
)
244314
cached_results[self.id] = edge_set
245315
return edge_set
246316

@@ -253,7 +323,7 @@ def descendants_edges(self):
253323
def ancestors_edges_ids(self, cached_results=None):
254324
"""
255325
Returns a set of ancestors edges
256-
# ToDo: Modify to use CTE
326+
# ToDo: Modify to use CTE and sort topologically
257327
"""
258328
if cached_results is None:
259329
cached_results = dict()
@@ -263,7 +333,9 @@ def ancestors_edges_ids(self, cached_results=None):
263333
edge_set = set()
264334
for f in self.parents.all():
265335
edge_set.add(edge_model.objects.get(child=self.id, parent=f.id).id)
266-
edge_set.update(f.ancestors_edges_ids(cached_results=cached_results))
336+
edge_set.update(
337+
f.ancestors_edges_ids(cached_results=cached_results)
338+
)
267339
cached_results[self.id] = edge_set
268340
return edge_set
269341

@@ -292,7 +364,7 @@ def path_ids_list(
292364
self, target_node, directional=True, max_depth=20, max_paths=1
293365
):
294366
"""
295-
Returns a list of paths from self to target node, optionally in either
367+
Returns a list of paths from self to target node, optionally in either
296368
direction. The resulting lists are always sorted from root-side, toward
297369
leaf-side, regardless of the relative position of starting and ending nodes.
298370
@@ -337,14 +409,23 @@ def shortest_path(self, target_node, directional=True, max_depth=20):
337409
Returns a queryset of the shortest path
338410
"""
339411
return self.filter_order_ids(
340-
self.path_ids_list(target_node, directional=directional, max_depth=max_depth)[0]
412+
self.path_ids_list(
413+
target_node, directional=directional, max_depth=max_depth
414+
)[0]
341415
)
342416

343417
def distance(self, target_node, directional=True, max_depth=20):
344418
"""
345419
Returns the shortest hops count to the target node
346420
"""
347-
return len(self.path_ids_list(target_node, directional=directional, max_depth=max_depth)[0]) - 1
421+
return (
422+
len(
423+
self.path_ids_list(
424+
target_node, directional=directional, max_depth=max_depth
425+
)[0]
426+
)
427+
- 1
428+
)
348429

349430
def is_root(self):
350431
"""

0 commit comments

Comments
 (0)