Skip to content

Commit caa7678

Browse files
committed
Added generator for nth product and nth_products.
1 parent 649bd23 commit caa7678

File tree

2 files changed

+119
-2
lines changed

2 files changed

+119
-2
lines changed

graph/generate.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from functools import lru_cache
2+
from collections import defaultdict
3+
import math
4+
import random
15
import itertools
26
from graph.core import Graph
37

@@ -36,3 +40,84 @@ def grid(length, width, bidirectional=False):
3640
a, b = node_index[(i, j)], node_index[(i, j - 1)]
3741
g.add_edge(b, a, bidirectional=bidirectional)
3842
return g
43+
44+
45+
def nth_permutation(idx, length, alphabet=None, prefix=()):
46+
if alphabet is None:
47+
alphabet = [i for i in range(length)]
48+
if length == 0:
49+
return prefix
50+
else:
51+
branch_count = math.factorial(length - 1)
52+
for d in alphabet:
53+
if d not in prefix:
54+
if branch_count <= idx:
55+
idx -= branch_count
56+
else:
57+
return nth_permutation(idx, length - 1, alphabet, prefix + (d,))
58+
59+
60+
def nth_product(idx, *args):
61+
"""returns the nth product of the given iterables.
62+
63+
Args:
64+
idx (int): the index.
65+
*args: the iterables.
66+
"""
67+
if not isinstance(idx, int):
68+
raise TypeError(f"Expected int, not {type(idx)}")
69+
total = math.prod([len(a) for a in args])
70+
if idx < 0:
71+
idx += total
72+
if index < 0 or index >= total:
73+
raise IndexError(f"Index {index} out of range")
74+
75+
elements = ()
76+
for i in range(len(args)):
77+
offset = math.prod([len(a) for a in args[i:]]) // len(args[i])
78+
index = idx // offset
79+
elements += (args[i][index],)
80+
idx -= index * offset
81+
return elements
82+
83+
84+
def n_products(*args, n=20):
85+
"""
86+
Returns the nth product of the given iterables.
87+
"""
88+
if len(args) == 0:
89+
return ()
90+
if any(len(a) == 0 for a in args):
91+
raise ZeroDivisionError("Cannot generate products of empty iterables")
92+
93+
n = min(n, math.prod([len(a) for a in args]))
94+
step = math.prod([len(a) for a in args]) / n
95+
96+
for ni in range(n):
97+
ix = int(step * ni + step / 2)
98+
yield nth_product(ix, *args)
99+
100+
101+
def random_graph(size, degree=1.7, seed=1):
102+
if not isinstance(size, int):
103+
raise TypeError(f"Expected int, not {type(size)}")
104+
if not isinstance(degree, float):
105+
raise TypeError(f"Expected float, not {type(degree)}")
106+
if not isinstance(seed, int):
107+
raise TypeError(f"Expected int, not {type(seed)}")
108+
109+
g = Graph()
110+
nodes = list(range(size))
111+
rng = random.Random(seed)
112+
rng.shuffle(nodes)
113+
114+
edges = size * degree
115+
116+
gen_length = math.factorial(nodes) / math.factorial(nodes - 2)
117+
comb = int(gen_length / edges)
118+
119+
for i in nodes:
120+
n = i * comb
121+
a, b = nth_permutation(n, size, nodes)
122+
g.add_edge(a, b)
123+
return g

tests/test_generate.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import math, itertools
12
from graph.core import Graph
2-
from graph.generate import binary_tree, grid
3+
from graph.generate import binary_tree, grid, random_graph, n_products, nth_product
34

45
from tests.test_graph import graph3x3, graph4x4, graph5x5
56

@@ -15,12 +16,43 @@ def test_generate_grid_4x4_bidirectional():
1516
expected = graph4x4()
1617
assert g == expected
1718

19+
1820
def test_generate_grid_5x5_bidirectional():
1921
g = grid(5, 5, bidirectional=True)
2022
expected = graph5x5()
2123
assert g == expected
2224

25+
2326
def test_generate_binary_tree_3_levels():
2427
g = binary_tree(3)
2528
expected = Graph(from_list=[[0, 1], [0, 2], [1, 3], [1, 4], [2, 5], [2, 6]])
26-
assert g == expected
29+
assert g == expected
30+
31+
32+
def test_nth_p():
33+
a, b, c, d = [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2, 3, 4, 5, 6, 7]
34+
for i, comb in enumerate(itertools.product(*[a, b, c, d])):
35+
if i < 5 or i > 836 or i == 444:
36+
print(i, comb)
37+
assert comb == nth_product(i, a, b, c, d)
38+
39+
40+
def test_nth_products():
41+
a, b, c = [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]
42+
L = list(n_products(a, b, n=10))
43+
assert len(L) == 10
44+
assert L == [(1, 2), (1, 6), (2, 4), (3, 1), (3, 5), (4, 2), (4, 6), (5, 4), (6, 1), (6, 5)]
45+
46+
L = list(n_products(a, b, c, n=10))
47+
assert len(L) == 10
48+
# fmt:off
49+
assert L == [(1, 2, 5), (1, 6, 3), (2, 4, 1), (3, 1, 4), (3, 5, 2), (4, 2, 5), (4, 6, 3), (5, 4, 1), (6, 1, 4), (6, 5, 2)]
50+
# fmt:on
51+
52+
53+
def test_random_graphs():
54+
g1 = random_graph(10, 1.7, 1)
55+
g2 = random_graph(10, 1.7, 1)
56+
assert g1 == g2
57+
g3 = random_graph(10, 1.1, 1)
58+
assert g1 != g3

0 commit comments

Comments
 (0)