Skip to content

Commit 63db90b

Browse files
committed
feat(gfql ast): type edge by direction
1 parent fdeca4d commit 63db90b

File tree

2 files changed

+81
-4
lines changed

2 files changed

+81
-4
lines changed

graphistry/compute/ast.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,22 @@ def __init__(self,
400400
edge_query=edge_query
401401
)
402402

403+
@classmethod
404+
def from_json(cls, d: dict) -> 'ASTEdge':
405+
out = ASTEdgeForward(
406+
edge_match=maybe_filter_dict_from_json(d, 'edge_match'),
407+
hops=d['hops'] if 'hops' in d else None,
408+
to_fixed_point=d['to_fixed_point'] if 'to_fixed_point' in d else DEFAULT_FIXED_POINT,
409+
source_node_match=maybe_filter_dict_from_json(d, 'source_node_match'),
410+
destination_node_match=maybe_filter_dict_from_json(d, 'destination_node_match'),
411+
source_node_query=d['source_node_query'] if 'source_node_query' in d else None,
412+
destination_node_query=d['destination_node_query'] if 'destination_node_query' in d else None,
413+
edge_query=d['edge_query'] if 'edge_query' in d else None,
414+
name=d['name'] if 'name' in d else None
415+
)
416+
out.validate()
417+
return out
418+
403419
e_forward = ASTEdgeForward # noqa: E305
404420

405421
class ASTEdgeReverse(ASTEdge):
@@ -430,6 +446,22 @@ def __init__(self,
430446
edge_query=edge_query
431447
)
432448

449+
@classmethod
450+
def from_json(cls, d: dict) -> 'ASTEdge':
451+
out = ASTEdgeReverse(
452+
edge_match=maybe_filter_dict_from_json(d, 'edge_match'),
453+
hops=d['hops'] if 'hops' in d else None,
454+
to_fixed_point=d['to_fixed_point'] if 'to_fixed_point' in d else DEFAULT_FIXED_POINT,
455+
source_node_match=maybe_filter_dict_from_json(d, 'source_node_match'),
456+
destination_node_match=maybe_filter_dict_from_json(d, 'destination_node_match'),
457+
source_node_query=d['source_node_query'] if 'source_node_query' in d else None,
458+
destination_node_query=d['destination_node_query'] if 'destination_node_query' in d else None,
459+
edge_query=d['edge_query'] if 'edge_query' in d else None,
460+
name=d['name'] if 'name' in d else None
461+
)
462+
out.validate()
463+
return out
464+
433465
e_reverse = ASTEdgeReverse # noqa: E305
434466

435467
class ASTEdgeUndirected(ASTEdge):
@@ -460,6 +492,22 @@ def __init__(self,
460492
edge_query=edge_query
461493
)
462494

495+
@classmethod
496+
def from_json(cls, d: dict) -> 'ASTEdge':
497+
out = ASTEdgeUndirected(
498+
edge_match=maybe_filter_dict_from_json(d, 'edge_match'),
499+
hops=d['hops'] if 'hops' in d else None,
500+
to_fixed_point=d['to_fixed_point'] if 'to_fixed_point' in d else DEFAULT_FIXED_POINT,
501+
source_node_match=maybe_filter_dict_from_json(d, 'source_node_match'),
502+
destination_node_match=maybe_filter_dict_from_json(d, 'destination_node_match'),
503+
source_node_query=d['source_node_query'] if 'source_node_query' in d else None,
504+
destination_node_query=d['destination_node_query'] if 'destination_node_query' in d else None,
505+
edge_query=d['edge_query'] if 'edge_query' in d else None,
506+
name=d['name'] if 'name' in d else None
507+
)
508+
out.validate()
509+
return out
510+
463511
e_undirected = ASTEdgeUndirected # noqa: E305
464512
e = ASTEdgeUndirected # noqa: E305
465513

@@ -472,7 +520,17 @@ def from_json(o: JSONVal) -> Union[ASTNode, ASTEdge]:
472520
if o['type'] == 'Node':
473521
out = ASTNode.from_json(o)
474522
elif o['type'] == 'Edge':
475-
out = ASTEdge.from_json(o)
523+
if 'direction' in o:
524+
if o['direction'] == 'forward':
525+
out = ASTEdgeForward.from_json(o)
526+
elif o['direction'] == 'reverse':
527+
out = ASTEdgeReverse.from_json(o)
528+
elif o['direction'] == 'undirected':
529+
out = ASTEdgeUndirected.from_json(o)
530+
else:
531+
raise ValueError(f'Edge has unknown direction {o["direction"]}')
532+
else:
533+
raise ValueError('Edge missing direction')
476534
else:
477535
raise ValueError(f'Unknown type {o["type"]}')
478536
return out

graphistry/tests/compute/test_chain.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import os
22
import pandas as pd
3-
from graphistry.compute.predicates.is_in import is_in
4-
from graphistry.compute.predicates.numeric import gt
53
import pytest
64

7-
from graphistry.compute.ast import ASTNode, ASTEdge, n, e, e_undirected, e_forward
5+
from graphistry.compute.ast import ASTEdgeUndirected, ASTNode, ASTEdge, n, e, e_undirected, e_forward
86
from graphistry.compute.chain import Chain
7+
from graphistry.compute.predicates.is_in import IsIn, is_in
8+
from graphistry.compute.predicates.numeric import gt
99
from graphistry.tests.test_compute import CGFull
1010

1111

@@ -298,6 +298,25 @@ def test_chain_serialization_pred():
298298
o2 = d.to_json()
299299
assert o == o2
300300

301+
def test_chain_serialize_pred_is_in():
302+
303+
#from graphistry.compute.chain import Chain
304+
#from graphistry import e_undirected, is_in
305+
o = Chain([
306+
e_undirected(
307+
hops=1,
308+
edge_match={"source": is_in(options=[
309+
"Oakville Square",
310+
"Maplewood Square"
311+
])})
312+
]).to_json()
313+
d = Chain.from_json(o)
314+
assert isinstance(d.chain[0], ASTEdgeUndirected), f'got: {type(d.chain[0])}'
315+
assert d.chain[0].direction == 'undirected'
316+
assert d.chain[0].hops == 1
317+
assert isinstance(d.chain[0].edge_match['source'], IsIn)
318+
assert d.chain[0].edge_match['source'].options == ['Oakville Square', 'Maplewood Square']
319+
301320
def test_chain_simple_cudf_pd():
302321
nodes_df = pd.DataFrame({'id': [0, 1, 2], 'label': ['a', 'b', 'c']})
303322
edges_df = pd.DataFrame({'src': [0, 1, 2], 'dst': [1, 2, 0]})

0 commit comments

Comments
 (0)