|
4 | 4 |
|
5 | 5 | import inspect
|
6 | 6 | import timeit
|
7 |
| -from typing import Generator |
| 7 | +from itertools import product |
| 8 | +from typing import Generator, Union |
8 | 9 |
|
9 |
| -from tests.performance._constants import GRAPH_SETUP_CODES, SETUP_CODES |
10 |
| - |
11 |
| - |
12 |
| -def do_performance_test(code_to_test: str | dict[str, str], array_sizes: list[int], repeats: int): |
13 |
| - """Run the performance test for the given code.""" |
14 | 10 |
|
| 11 | +def do_performance_test( |
| 12 | + code_to_test: Union[str, dict[str, str], list[str]], |
| 13 | + size_list: list[int], |
| 14 | + repeats: int, |
| 15 | + setup_codes: dict[str, str], |
| 16 | +): |
| 17 | + """Generalized performance test runner.""" |
15 | 18 | print(f"{'-' * 20} {inspect.stack()[1][3]} {'-' * 20}")
|
16 | 19 |
|
17 |
| - for array_size in array_sizes: |
| 20 | + for size in size_list: |
| 21 | + formatted_setup_codes = {key: code.format(size=size) for key, code in setup_codes.items()} |
18 | 22 | if isinstance(code_to_test, dict):
|
19 |
| - code_to_test_list = [code_to_test[variant].format(array_size=array_size) for variant in SETUP_CODES] |
20 |
| - else: |
21 |
| - code_to_test_list = [code_to_test.format(array_size=array_size)] * len(SETUP_CODES) |
22 |
| - print(f"\n\tArray size: {array_size}\n") |
23 |
| - setup_codes = [setup_code.format(array_size=array_size) for setup_code in SETUP_CODES.values()] |
24 |
| - timings = _get_timings(setup_codes, code_to_test_list, repeats) |
25 |
| - |
26 |
| - if code_to_test == "pass": |
27 |
| - _print_timings(timings, list(SETUP_CODES.keys()), setup_codes) |
| 23 | + code_to_test_list = [code_to_test[variant].format(size=size) for variant in setup_codes] |
| 24 | + test_generator = zip(formatted_setup_codes.items(), code_to_test_list) |
| 25 | + elif isinstance(code_to_test, list): |
| 26 | + code_to_test_list = [code.format(size=size) for code in code_to_test] |
| 27 | + test_generator = product(formatted_setup_codes.items(), code_to_test_list) |
28 | 28 | else:
|
29 |
| - _print_timings(timings, list(SETUP_CODES.keys()), code_to_test_list) |
30 |
| - print() |
| 29 | + test_generator = product(formatted_setup_codes.items(), [code_to_test.format(size=size)]) |
31 | 30 |
|
| 31 | + print(f"\n\tsize: {size}\n") |
32 | 32 |
|
33 |
| -def do_graph_test(code_to_test: str | dict[str, str], graph_sizes: list[int], repeats: int): |
34 |
| - """Run the performance test for the given code.""" |
| 33 | + timings = _get_timings(test_generator, repeats=repeats) |
| 34 | + _print_timings(timings) |
35 | 35 |
|
36 |
| - print(f"{'-' * 20} {inspect.stack()[1][3]} {'-' * 20}") |
37 |
| - |
38 |
| - for graph_size in graph_sizes: |
39 |
| - if isinstance(code_to_test, dict): |
40 |
| - code_to_test_list = [code_to_test[variant] for variant in GRAPH_SETUP_CODES] |
41 |
| - else: |
42 |
| - code_to_test_list = [code_to_test] * len(GRAPH_SETUP_CODES) |
43 |
| - print(f"\n\tGraph size: {graph_size}\n") |
44 |
| - setup_codes = [setup_code.format(graph_size=graph_size) for setup_code in GRAPH_SETUP_CODES.values()] |
45 |
| - timings = _get_timings(setup_codes, code_to_test_list, repeats) |
46 |
| - |
47 |
| - if code_to_test == "pass": |
48 |
| - _print_graph_timings(timings, list(GRAPH_SETUP_CODES.keys()), setup_codes) |
49 |
| - else: |
50 |
| - _print_graph_timings(timings, list(GRAPH_SETUP_CODES.keys()), code_to_test_list) |
51 | 36 | print()
|
52 | 37 |
|
53 | 38 |
|
54 |
| -def _print_test_code(code: str | dict[str, str], repeats: int): |
55 |
| - print(f"{'-' * 40}") |
56 |
| - if isinstance(code, dict): |
57 |
| - for variant, code_variant in code.items(): |
58 |
| - print(f"{variant}") |
59 |
| - print(f"\t{code_variant} (x {repeats})") |
60 |
| - return |
61 |
| - print(f"{code} (x {repeats})") |
62 |
| - |
63 |
| - |
64 |
| -def _print_graph_timings(timings: Generator, graph_types: list[str], code_list: list[str]): |
65 |
| - for graph_type, timing, code in zip(graph_types, timings, code_list): |
66 |
| - if ";" in code: |
67 |
| - code = code.split(";")[-1] |
68 |
| - |
69 |
| - code = code.replace("\n", " ").replace("\t", " ") |
70 |
| - code = f"{graph_type}: " + code |
71 |
| - |
72 |
| - if isinstance(timing, Exception): |
73 |
| - print(f"\t\t{code.ljust(100)} | Not supported") |
74 |
| - continue |
75 |
| - print(f"\t\t{code.ljust(100)} | {sum(timing):.2f}s") |
76 |
| - |
77 |
| - |
78 |
| -def _print_timings(timings: Generator, array_types: list[str], code_list: list[str]): |
79 |
| - for array, timing, code in zip(array_types, timings, code_list): |
80 |
| - if ";" in code: |
81 |
| - code = code.split(";")[-1] |
82 |
| - |
83 |
| - code = code.replace("\n", " ").replace("\t", " ") |
84 |
| - array_name = f"{array}_array" |
85 |
| - code = code.replace("input_array", array_name) |
| 39 | +def _print_timings(timings: Generator): |
| 40 | + for key, code, timing in timings: |
| 41 | + code = code.split(";")[-1].replace("\n", " ").replace("\t", " ") |
| 42 | + code = f"{key}: {code}" |
86 | 43 |
|
87 | 44 | if isinstance(timing, Exception):
|
88 | 45 | print(f"\t\t{code.ljust(100)} | Not supported")
|
89 | 46 | continue
|
90 | 47 | print(f"\t\t{code.ljust(100)} | {sum(timing):.2f}s")
|
91 | 48 |
|
92 | 49 |
|
93 |
| -def _get_timings(setup_codes: list[str], test_codes: list[str], repeats: int): |
| 50 | +def _get_timings(test_generator, repeats: int): |
94 | 51 | """Return a generator with the timings for each array type."""
|
95 |
| - for setup_code, test_code in zip(setup_codes, test_codes): |
| 52 | + for (key, setup_code), test_code in test_generator: |
96 | 53 | if test_code == "pass":
|
97 |
| - yield timeit.repeat(setup_code, number=repeats) |
| 54 | + yield key, "intialise", timeit.repeat(setup_code, number=repeats) |
98 | 55 | else:
|
99 | 56 | try:
|
100 |
| - yield timeit.repeat(test_code, setup_code, number=repeats) |
| 57 | + yield key, test_code, timeit.repeat(test_code, setup_code, number=repeats) |
101 | 58 | # pylint: disable=broad-exception-caught
|
102 | 59 | except Exception as error: # noqa
|
103 |
| - yield error |
| 60 | + yield key, test_code, error |
0 commit comments