12
12
https://github.com/stdbrouw/django-treebeard-dag
13
13
"""
14
14
15
+ from django .apps import apps
15
16
from django .db import models , connection
16
17
from django .db .models import Case , When
17
18
from django .core .exceptions import ValidationError
18
19
20
+ LIMITING_CLAUSE_1 = """AND second.{fk_field_name}_id = {fk_id}"""
21
+ LIMITING_CLAUSE_2 = """WHERE {relationship_table}.{fk_field_name}_id = {fk_id}"""
19
22
20
23
ANCESTOR_QUERY = """
21
24
WITH RECURSIVE traverse(id, depth) AS (
24
27
LEFT OUTER JOIN {relationship_table} AS second
25
28
ON first.parent_id = second.child_id
26
29
WHERE first.child_id = %(id)s
30
+ {limiting_clause_1}
27
31
UNION
28
32
SELECT DISTINCT parent_id, traverse.depth + 1
29
33
FROM traverse
30
34
INNER JOIN {relationship_table}
31
35
ON {relationship_table}.child_id = traverse.id
36
+ {limiting_clause_2}
32
37
)
33
38
SELECT id FROM traverse
34
39
GROUP BY id
42
47
LEFT OUTER JOIN {relationship_table} AS second
43
48
ON first.child_id = second.parent_id
44
49
WHERE first.parent_id = %(id)s
50
+ {limiting_clause_1}
45
51
UNION
46
52
SELECT DISTINCT child_id, traverse.depth + 1
47
53
FROM traverse
48
54
INNER JOIN {relationship_table}
49
55
ON {relationship_table}.parent_id = traverse.id
56
+ {limiting_clause_2}
50
57
)
51
58
SELECT id FROM traverse
52
59
GROUP BY id
@@ -120,7 +127,7 @@ def _filter_order(queryset, field_names, values):
120
127
For instance
121
128
_filter_order(self.__class__.objects, "pk", ids)
122
129
returns a queryset of the current class, with instances where the 'pk' field matches an id in ids
123
-
130
+
124
131
"""
125
132
if not isinstance (field_names , list ):
126
133
field_names = [field_names ]
@@ -136,6 +143,17 @@ def _filter_order(queryset, field_names, values):
136
143
def node_factory (edge_model , children_null = True , base_model = models .Model ):
137
144
edge_model_table = edge_model ._meta .db_table
138
145
146
+ def get_foreign_key_field (instance = None ):
147
+ """
148
+ Provided a model instance, checks if the edge model has a ForeignKey field to the model for that instance, and then returns the field name and instance id
149
+ """
150
+ if instance is not None :
151
+ for field in edge_model ._meta .get_fields ():
152
+ if field .related_model is instance ._meta .model :
153
+ # Return the first field that matches
154
+ return (field .name , instance .id )
155
+ return (None , None )
156
+
139
157
class Node (base_model ):
140
158
children = models .ManyToManyField (
141
159
"self" ,
@@ -169,68 +187,118 @@ def filter_order_ids(self, ids):
169
187
"""
170
188
return _filter_order (self .__class__ .objects , "pk" , ids )
171
189
172
- def ancestors_ids (self ):
190
+ def ancestors_ids (self , limiting_instance = None ):
191
+ fk_field_name , fk_value = get_foreign_key_field (limiting_instance )
192
+ if fk_field_name is not None and fk_value is not None :
193
+ limiting_clause_1 = LIMITING_CLAUSE_1 .format (
194
+ fk_field_name = fk_field_name , fk_value = fk_value
195
+ )
196
+ limiting_clause_2 = LIMITING_CLAUSE_2 .format (
197
+ relationship_table = edge_model_table ,
198
+ fk_field_name = fk_field_name ,
199
+ fk_value = fk_value ,
200
+ )
201
+ else :
202
+ limiting_clause_1 , limiting_clause_2 = ("" , "" )
203
+
173
204
with connection .cursor () as cursor :
174
205
cursor .execute (
175
- ANCESTOR_QUERY .format (relationship_table = edge_model_table ),
206
+ ANCESTOR_QUERY .format (
207
+ relationship_table = edge_model_table ,
208
+ limiting_clause_1 = limiting_clause_1 ,
209
+ limiting_clause_2 = limiting_clause_2 ,
210
+ ),
176
211
{"id" : self .id },
177
212
)
178
213
return [row [0 ] for row in cursor .fetchall ()]
179
214
180
- def ancestors_and_self_ids (self ):
181
- return self .ancestors_ids () + [self .id ]
215
+ def ancestors_and_self_ids (self , limiting_instance = None ):
216
+ return self .ancestors_ids (limiting_instance = limiting_instance ) + [self .id ]
217
+
218
+ def self_and_ancestors_ids (self , limiting_instance = None ):
219
+ return self .ancestors_and_self_ids (limiting_instance = limiting_instance )[
220
+ ::- 1
221
+ ]
182
222
183
- def self_and_ancestors_ids (self ):
184
- return self .ancestors_and_self_ids ()[::- 1 ]
223
+ def ancestors (self , limiting_instance = None ):
224
+ return self .filter_order_ids (
225
+ self .ancestors_ids (limiting_instance = limiting_instance )
226
+ )
185
227
186
- def ancestors (self ):
187
- return self .filter_order_ids (self .ancestors_ids ())
228
+ def ancestors_and_self (self , limiting_instance = None ):
229
+ return self .filter_order_ids (
230
+ self .ancestors_and_self_ids (limiting_instance = limiting_instance )
231
+ )
188
232
189
- def ancestors_and_self (self ):
190
- return self .filter_order_ids ( self . ancestors_and_self_ids ())
233
+ def self_and_ancestors (self , limiting_instance = None ):
234
+ return self .ancestors_and_self ( limiting_instance = limiting_instance )[:: - 1 ]
191
235
192
- def self_and_ancestors (self ):
193
- return self .ancestors_and_self ()[::- 1 ]
236
+ def descendants_ids (self , limiting_instance = None ):
237
+ fk_field_name , fk_value = get_foreign_key_field (limiting_instance )
238
+ if fk_field_name is not None and fk_value is not None :
239
+ limiting_clause_1 = LIMITING_CLAUSE_1 .format (
240
+ fk_field_name = fk_field_name , fk_value = fk_value
241
+ )
242
+ limiting_clause_2 = LIMITING_CLAUSE_2 .format (
243
+ relationship_table = edge_model_table ,
244
+ fk_field_name = fk_field_name ,
245
+ fk_value = fk_value ,
246
+ )
247
+ else :
248
+ limiting_clause_1 , limiting_clause_2 = ("" , "" )
194
249
195
- def descendants_ids (self ):
196
250
with connection .cursor () as cursor :
197
251
cursor .execute (
198
- DESCENDANT_QUERY .format (relationship_table = edge_model_table ),
252
+ DESCENDANT_QUERY .format (
253
+ relationship_table = edge_model_table ,
254
+ limiting_clause_1 = limiting_clause_1 ,
255
+ limiting_clause_2 = limiting_clause_2 ,
256
+ ),
199
257
{"id" : self .id },
200
258
)
201
259
return [row [0 ] for row in cursor .fetchall ()]
202
260
203
- def self_and_descendants_ids (self ):
204
- return [self .id ] + self .descendants_ids ()
261
+ def self_and_descendants_ids (self , limiting_instance = None ):
262
+ return [self .id ] + self .descendants_ids (limiting_instance = limiting_instance )
205
263
206
- def descendants_and_self_ids (self ):
207
- return self .self_and_descendants_ids ()[::- 1 ]
264
+ def descendants_and_self_ids (self , limiting_instance = None ):
265
+ return self .self_and_descendants_ids (limiting_instance = limiting_instance )[
266
+ ::- 1
267
+ ]
208
268
209
- def descendants (self ):
210
- return self .filter_order_ids (self .descendants_ids ())
269
+ def descendants (self , limiting_instance = None ):
270
+ return self .filter_order_ids (
271
+ self .descendants_ids (limiting_instance = limiting_instance )
272
+ )
211
273
212
- def self_and_descendants (self ):
213
- return self .filter_order_ids (self .self_and_descendants_ids ())
274
+ def self_and_descendants (self , limiting_instance = None ):
275
+ return self .filter_order_ids (
276
+ self .self_and_descendants_ids (limiting_instance = limiting_instance )
277
+ )
214
278
215
- def descendants_and_self (self ):
216
- return self .self_and_descendants ()[::- 1 ]
279
+ def descendants_and_self (self , limiting_instance = None ):
280
+ return self .self_and_descendants (limiting_instance = limiting_instance )[::- 1 ]
217
281
218
- def clan_ids (self ):
282
+ def clan_ids (self , limiting_instance = None ):
219
283
"""
220
284
Returns a list of ids with all ancestors, self, and all descendants
221
285
"""
222
- return self .ancestors_ids () + self .self_and_descendants_ids ()
286
+ return self .ancestors_ids (
287
+ limiting_instance = limiting_instance
288
+ ) + self .self_and_descendants_ids (limiting_instance = limiting_instance )
223
289
224
- def clan (self ):
290
+ def clan (self , limiting_instance = None ):
225
291
"""
226
292
Returns a queryset with all ancestors, self, and all descendants
227
293
"""
228
- return self .filter_order_ids (self .clan_ids ())
294
+ return self .filter_order_ids (
295
+ self .clan_ids (limiting_instance = limiting_instance )
296
+ )
229
297
230
298
def descendants_edges_ids (self , cached_results = None ):
231
299
"""
232
300
Returns a set of descendants edges
233
- # ToDo: Modify to use CTE
301
+ # ToDo: Modify to use CTE and sort topologically
234
302
"""
235
303
if cached_results is None :
236
304
cached_results = dict ()
@@ -240,7 +308,9 @@ def descendants_edges_ids(self, cached_results=None):
240
308
edge_set = set ()
241
309
for f in self .children .all ():
242
310
edge_set .add (edge_model .objects .get (parent = self .id , child = f .id ).id )
243
- edge_set .update (f .descendants_edges_ids (cached_results = cached_results ))
311
+ edge_set .update (
312
+ f .descendants_edges_ids (cached_results = cached_results )
313
+ )
244
314
cached_results [self .id ] = edge_set
245
315
return edge_set
246
316
@@ -253,7 +323,7 @@ def descendants_edges(self):
253
323
def ancestors_edges_ids (self , cached_results = None ):
254
324
"""
255
325
Returns a set of ancestors edges
256
- # ToDo: Modify to use CTE
326
+ # ToDo: Modify to use CTE and sort topologically
257
327
"""
258
328
if cached_results is None :
259
329
cached_results = dict ()
@@ -263,7 +333,9 @@ def ancestors_edges_ids(self, cached_results=None):
263
333
edge_set = set ()
264
334
for f in self .parents .all ():
265
335
edge_set .add (edge_model .objects .get (child = self .id , parent = f .id ).id )
266
- edge_set .update (f .ancestors_edges_ids (cached_results = cached_results ))
336
+ edge_set .update (
337
+ f .ancestors_edges_ids (cached_results = cached_results )
338
+ )
267
339
cached_results [self .id ] = edge_set
268
340
return edge_set
269
341
@@ -292,7 +364,7 @@ def path_ids_list(
292
364
self , target_node , directional = True , max_depth = 20 , max_paths = 1
293
365
):
294
366
"""
295
- Returns a list of paths from self to target node, optionally in either
367
+ Returns a list of paths from self to target node, optionally in either
296
368
direction. The resulting lists are always sorted from root-side, toward
297
369
leaf-side, regardless of the relative position of starting and ending nodes.
298
370
@@ -337,14 +409,23 @@ def shortest_path(self, target_node, directional=True, max_depth=20):
337
409
Returns a queryset of the shortest path
338
410
"""
339
411
return self .filter_order_ids (
340
- self .path_ids_list (target_node , directional = directional , max_depth = max_depth )[0 ]
412
+ self .path_ids_list (
413
+ target_node , directional = directional , max_depth = max_depth
414
+ )[0 ]
341
415
)
342
416
343
417
def distance (self , target_node , directional = True , max_depth = 20 ):
344
418
"""
345
419
Returns the shortest hops count to the target node
346
420
"""
347
- return len (self .path_ids_list (target_node , directional = directional , max_depth = max_depth )[0 ]) - 1
421
+ return (
422
+ len (
423
+ self .path_ids_list (
424
+ target_node , directional = directional , max_depth = max_depth
425
+ )[0 ]
426
+ )
427
+ - 1
428
+ )
348
429
349
430
def is_root (self ):
350
431
"""
0 commit comments