Skip to content

Commit e06661e

Browse files
committed
refactored nth product for ease of reading.
1 parent c2719f4 commit e06661e

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

graph/generate.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from functools import lru_cache
2-
from collections import defaultdict
31
import math
42
import random
53
import itertools
@@ -42,19 +40,23 @@ def grid(length, width, bidirectional=False):
4240
return g
4341

4442

45-
def nth_permutation(idx, length, alphabet=None, prefix=()):
46-
if alphabet is None:
47-
alphabet = [i for i in range(length)]
48-
if length == 0:
49-
return prefix
50-
else:
51-
branch_count = math.factorial(length - 1)
52-
for d in alphabet:
53-
if d not in prefix:
54-
if branch_count <= idx:
55-
idx -= branch_count
56-
else:
57-
return nth_permutation(idx, length - 1, alphabet, prefix + (d,))
43+
def _nth_product_no_typecheck(idx, *args):
44+
"""returns the nth product of the given iterables.
45+
46+
Args:
47+
idx (int): the index.
48+
*args: the iterables.
49+
50+
Returns:
51+
tuple: the elements at the given index.
52+
"""
53+
elements = ()
54+
for i in range(len(args)):
55+
offset = math.prod([len(a) for a in args[i:]]) // len(args[i])
56+
index = idx // offset
57+
elements += (args[i][index],)
58+
idx -= index * offset
59+
return elements
5860

5961

6062
def nth_product(idx, *args):
@@ -71,19 +73,13 @@ def nth_product(idx, *args):
7173
idx += total
7274
if idx < 0 or idx >= total:
7375
raise IndexError(f"Index {idx} out of range")
74-
75-
elements = ()
76-
for i in range(len(args)):
77-
offset = math.prod([len(a) for a in args[i:]]) // len(args[i])
78-
index = idx // offset
79-
elements += (args[i][index],)
80-
idx -= index * offset
81-
return elements
76+
return _nth_product_no_typecheck(idx, *args)
8277

8378

84-
def n_products(n, *args):
79+
def nth_products(n, *args):
8580
"""
86-
Returns the nth product of the given iterables.
81+
Returns the n evenly spread combinations using
82+
nth product of the given iterables.
8783
8884
Args:
8985
n (int): the number of products to generate.
@@ -99,10 +95,20 @@ def n_products(n, *args):
9995

10096
for ni in range(n):
10197
ix = int(step * ni + step / 2)
102-
yield nth_product(ix, *args)
98+
yield _nth_product_no_typecheck(ix, *args)
10399

104100

105101
def random_graph(size, degree=1.7, seed=1):
102+
"""Generates a graph with randomized edges
103+
104+
Args:
105+
size (int): number of nodes
106+
degree (float, optional): Average degree of connectivity. Defaults to 1.7.
107+
seed (int, optional): Random seed. Defaults to 1.
108+
109+
Returns:
110+
Graph: the generated graph
111+
"""
106112
if not isinstance(size, int):
107113
raise TypeError(f"Expected int, not {type(size)}")
108114
if not isinstance(degree, float):
@@ -117,7 +123,7 @@ def random_graph(size, degree=1.7, seed=1):
117123

118124
edges = int(size * degree)
119125

120-
L = n_products(edges, nodes, nodes)
126+
L = nth_products(edges, nodes, nodes)
121127
for a, b in L:
122128
g.add_edge(a, b)
123129
return g

tests/test_generate.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math, itertools
22
from graph.core import Graph
3-
from graph.generate import binary_tree, grid, random_graph, n_products, nth_product
3+
from graph.generate import binary_tree, grid, random_graph, nth_products, nth_product
44

55
from tests.test_graph import graph3x3, graph4x4, graph5x5
66

@@ -39,11 +39,11 @@ def test_nth_p():
3939

4040
def test_nth_products():
4141
a, b, c = [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]
42-
L = list(n_products(10, a, b))
42+
L = list(nth_products(10, a, b))
4343
assert len(L) == 10
4444
assert L == [(1, 2), (1, 6), (2, 4), (3, 1), (3, 5), (4, 2), (4, 6), (5, 4), (6, 1), (6, 5)]
4545

46-
L = list(n_products(10, a, b, c))
46+
L = list(nth_products(10, a, b, c))
4747
assert len(L) == 10
4848
# fmt:off
4949
assert L == [(1, 2, 5), (1, 6, 3), (2, 4, 1), (3, 1, 4), (3, 5, 2), (4, 2, 5), (4, 6, 3), (5, 4, 1), (6, 1, 4), (6, 5, 2)]
@@ -52,7 +52,9 @@ def test_nth_products():
5252

5353
def test_random_graphs():
5454
g1 = random_graph(10, 1.7, 1)
55-
g2 = random_graph(10, 1.7, 1)
55+
assert len(g1.edges()) == 10 * 1.7
56+
57+
g2 = random_graph(10, 1.7, 1) # same seed, same graph
5658
assert g1 == g2
57-
g3 = random_graph(10, 1.1, 1)
58-
assert g1 != g3
59+
g3 = random_graph(10, 1.1, 1) # same seed, different graph
60+
assert len(g3.edges()) == 10 * 1.1

0 commit comments

Comments
 (0)