Skip to content

Commit 75f0e37

Browse files
committed
minor update.
1 parent e06661e commit 75f0e37

File tree

1 file changed

+26
-29
lines changed

1 file changed

+26
-29
lines changed

graph/generate.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import math
22
import random
33
import itertools
4+
from functools import reduce
5+
from operator import mul
46
from graph.core import Graph
57

68

@@ -40,40 +42,35 @@ def grid(length, width, bidirectional=False):
4042
return g
4143

4244

43-
def _nth_product_no_typecheck(idx, *args):
44-
"""returns the nth product of the given iterables.
45+
def nth_product(index, *args):
46+
"""Equivalent to ``list(product(*args))[index]``.
4547
46-
Args:
47-
idx (int): the index.
48-
*args: the iterables.
48+
The products of *args* can be ordered lexicographically.
49+
:func:`nth_product` computes the product at sort position *index* without
50+
computing the previous products.
4951
50-
Returns:
51-
tuple: the elements at the given index.
52+
>>> nth_product(8, range(2), range(2), range(2), range(2))
53+
(1, 0, 0, 0)
54+
55+
``IndexError`` will be raised if the given *index* is invalid.
5256
"""
53-
elements = ()
54-
for i in range(len(args)):
55-
offset = math.prod([len(a) for a in args[i:]]) // len(args[i])
56-
index = idx // offset
57-
elements += (args[i][index],)
58-
idx -= index * offset
59-
return elements
57+
pools = list(map(tuple, reversed(args)))
58+
ns = list(map(len, pools))
6059

60+
c = reduce(mul, ns)
6161

62-
def nth_product(idx, *args):
63-
"""returns the nth product of the given iterables.
62+
if index < 0:
63+
index += c
6464

65-
Args:
66-
idx (int): the index.
67-
*args: the iterables.
68-
"""
69-
if not isinstance(idx, int):
70-
raise TypeError(f"Expected int, not {type(idx)}")
71-
total = math.prod([len(a) for a in args])
72-
if idx < 0:
73-
idx += total
74-
if idx < 0 or idx >= total:
75-
raise IndexError(f"Index {idx} out of range")
76-
return _nth_product_no_typecheck(idx, *args)
65+
if not 0 <= index < c:
66+
raise IndexError
67+
68+
result = []
69+
for pool, n in zip(pools, ns):
70+
result.append(pool[index % n])
71+
index //= n
72+
73+
return tuple(reversed(result))
7774

7875

7976
def nth_products(n, *args):
@@ -95,7 +92,7 @@ def nth_products(n, *args):
9592

9693
for ni in range(n):
9794
ix = int(step * ni + step / 2)
98-
yield _nth_product_no_typecheck(ix, *args)
95+
yield nth_product(ix, *args)
9996

10097

10198
def random_graph(size, degree=1.7, seed=1):

0 commit comments

Comments
 (0)