Skip to content

Commit fbdcfd4

Browse files
authored
New feature: change_layout method for Graph mobject (#945)
* factor out code for determining the graph layout * add change_layout method
1 parent b87bdd9 commit fbdcfd4

File tree

1 file changed

+123
-71
lines changed

1 file changed

+123
-71
lines changed

manim/mobject/graph.py

Lines changed: 123 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,81 @@
1616
import numpy as np
1717

1818

19+
def _determine_graph_layout(
20+
nx_graph: nx.classes.graph.Graph,
21+
layout: Union[str, dict] = "spring",
22+
layout_scale: float = 2,
23+
layout_config: Union[dict, None] = None,
24+
partitions: Union[List[List[Hashable]], None] = None,
25+
root_vertex: Union[Hashable, None] = None,
26+
) -> dict:
27+
automatic_layouts = {
28+
"circular": nx.layout.circular_layout,
29+
"kamada_kawai": nx.layout.kamada_kawai_layout,
30+
"planar": nx.layout.planar_layout,
31+
"random": nx.layout.random_layout,
32+
"shell": nx.layout.shell_layout,
33+
"spectral": nx.layout.spectral_layout,
34+
"partite": nx.layout.multipartite_layout,
35+
"tree": _tree_layout,
36+
"spiral": nx.layout.spiral_layout,
37+
"spring": nx.layout.spring_layout,
38+
}
39+
40+
custom_layouts = ["random", "partite", "tree"]
41+
42+
if layout_config is None:
43+
layout_config = {}
44+
45+
if isinstance(layout, dict):
46+
return layout
47+
elif layout in automatic_layouts and layout not in custom_layouts:
48+
auto_layout = automatic_layouts[layout](
49+
nx_graph, scale=layout_scale, **layout_config
50+
)
51+
return dict([(k, np.append(v, [0])) for k, v in auto_layout.items()])
52+
elif layout == "tree":
53+
return _tree_layout(
54+
nx_graph,
55+
root_vertex=root_vertex,
56+
scale=layout_scale,
57+
)
58+
elif layout == "partite":
59+
if partitions is None or len(partitions) == 0:
60+
raise ValueError(
61+
"The partite layout requires the 'partitions' parameter to contain the partition of the vertices"
62+
)
63+
partition_count = len(partitions)
64+
for i in range(partition_count):
65+
for v in partitions[i]:
66+
if nx_graph.nodes[v] is None:
67+
raise ValueError(
68+
"The partition must contain arrays of vertices in the graph"
69+
)
70+
nx_graph.nodes[v]["subset"] = i
71+
# Add missing vertices to their own side
72+
for v in nx_graph.nodes:
73+
if "subset" not in nx_graph.nodes[v]:
74+
nx_graph.nodes[v]["subset"] = partition_count
75+
76+
auto_layout = automatic_layouts["partite"](
77+
nx_graph, scale=layout_scale, **layout_config
78+
)
79+
return dict([(k, np.append(v, [0])) for k, v in auto_layout.items()])
80+
elif layout == "random":
81+
# the random layout places coordinates in [0, 1)
82+
# we need to rescale manually afterwards...
83+
auto_layout = automatic_layouts["random"](nx_graph, **layout_config)
84+
for k, v in auto_layout.items():
85+
auto_layout[k] = 2 * layout_scale * (v - np.array([0.5, 0.5]))
86+
return dict([(k, np.append(v, [0])) for k, v in auto_layout.items()])
87+
else:
88+
raise ValueError(
89+
f"The layout '{layout}' is neither a recognized automatic layout, "
90+
"nor a vertex placement dictionary."
91+
)
92+
93+
1994
def _tree_layout(
2095
G: nx.classes.graph.Graph,
2196
root_vertex: Union[Hashable, None],
@@ -289,77 +364,14 @@ def __init__(
289364
nx_graph.add_edges_from(edges)
290365
self._graph = nx_graph
291366

292-
automatic_layouts = {
293-
"circular": nx.layout.circular_layout,
294-
"kamada_kawai": nx.layout.kamada_kawai_layout,
295-
"planar": nx.layout.planar_layout,
296-
"random": nx.layout.random_layout,
297-
"shell": nx.layout.shell_layout,
298-
"spectral": nx.layout.spectral_layout,
299-
"partite": nx.layout.multipartite_layout,
300-
"tree": _tree_layout,
301-
"spiral": nx.layout.spiral_layout,
302-
"spring": nx.layout.spring_layout,
303-
}
304-
305-
custom_layouts = ["random", "partite", "tree"]
306-
307-
if layout_config is None:
308-
layout_config = {}
309-
310-
if isinstance(layout, dict):
311-
self._layout = layout
312-
elif layout in automatic_layouts and layout not in custom_layouts:
313-
self._layout = automatic_layouts[layout](
314-
nx_graph, scale=layout_scale, **layout_config
315-
)
316-
self._layout = dict(
317-
[(k, np.append(v, [0])) for k, v in self._layout.items()]
318-
)
319-
elif layout == "tree":
320-
self._layout = automatic_layouts[layout](
321-
nx_graph,
322-
root_vertex=root_vertex,
323-
scale=layout_scale,
324-
)
325-
elif layout == "partite":
326-
if partitions is None or len(partitions) == 0:
327-
raise ValueError(
328-
"The partite layout requires the 'partitions' parameter to contain the partition of the vertices"
329-
)
330-
partition_count = len(partitions)
331-
for i in range(partition_count):
332-
for v in partitions[i]:
333-
if nx_graph.nodes[v] is None:
334-
raise ValueError(
335-
"The partition must contain arrays of vertices in the graph"
336-
)
337-
nx_graph.nodes[v]["subset"] = i
338-
# Add missing vertices to their own side
339-
for v in nx_graph.nodes:
340-
if "subset" not in nx_graph.nodes[v]:
341-
nx_graph.nodes[v]["subset"] = partition_count
342-
343-
self._layout = automatic_layouts["partite"](
344-
nx_graph, scale=layout_scale, **layout_config
345-
)
346-
self._layout = dict(
347-
[(k, np.append(v, [0])) for k, v in self._layout.items()]
348-
)
349-
elif layout == "random":
350-
# the random layout places coordinates in [0, 1)
351-
# we need to rescale manually afterwards...
352-
self._layout = automatic_layouts["random"](nx_graph, **layout_config)
353-
for k, v in self._layout.items():
354-
self._layout[k] = 2 * layout_scale * (v - np.array([0.5, 0.5]))
355-
self._layout = dict(
356-
[(k, np.append(v, [0])) for k, v in self._layout.items()]
357-
)
358-
else:
359-
raise ValueError(
360-
f"The layout '{layout}' is neither a recognized automatic layout, "
361-
"nor a vertex placement dictionary."
362-
)
367+
self._layout = _determine_graph_layout(
368+
nx_graph,
369+
layout=layout,
370+
layout_scale=layout_scale,
371+
layout_config=layout_config,
372+
partitions=partitions,
373+
root_vertex=root_vertex,
374+
)
363375

364376
if isinstance(labels, dict):
365377
self._labels = labels
@@ -475,3 +487,43 @@ def construct(self):
475487
476488
"""
477489
return Graph(list(nxgraph.nodes), list(nxgraph.edges), **kwargs)
490+
491+
def change_layout(
492+
self,
493+
layout: Union[str, dict] = "spring",
494+
layout_scale: float = 2,
495+
layout_config: Union[dict, None] = None,
496+
partitions: Union[List[List[Hashable]], None] = None,
497+
root_vertex: Union[Hashable, None] = None,
498+
) -> "Graph":
499+
"""Change the layout of this graph.
500+
501+
See the documentation of :class:`~.Graph` for details about the
502+
keyword arguments.
503+
504+
Examples
505+
--------
506+
507+
.. manim:: ChangeGraphLayout
508+
509+
class ChangeGraphLayout(Scene):
510+
def construct(self):
511+
G = Graph([1, 2, 3, 4, 5], [(1, 2), (2, 3), (3, 4), (4, 5)],
512+
layout={1: [-2, 0, 0], 2: [-1, 0, 0], 3: [0, 0, 0],
513+
4: [1, 0, 0], 5: [2, 0, 0]}
514+
)
515+
self.play(ShowCreation(G))
516+
self.play(G.animate.change_layout("circular"))
517+
self.wait()
518+
"""
519+
self._layout = _determine_graph_layout(
520+
self._graph,
521+
layout=layout,
522+
layout_scale=layout_scale,
523+
layout_config=layout_config,
524+
partitions=partitions,
525+
root_vertex=root_vertex,
526+
)
527+
for v in self.vertices:
528+
self[v].move_to(self._layout[v])
529+
return self

0 commit comments

Comments
 (0)