Skip to content

Commit cf9e14d

Browse files
committed
Added parametric tests for cover algorithms
1 parent f4c74cc commit cf9e14d

1 file changed

Lines changed: 118 additions & 150 deletions

File tree

tests/test_unit_cover.py

Lines changed: 118 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
1+
"""
2+
Unit tests for the cover algorithms.
3+
"""
4+
15
import numpy as np
26
import pytest
37

48
from tdamapper.core import TrivialCover
5-
from tdamapper.cover import BallCover, CubicalCover, KNNCover
9+
from tdamapper.cover import (
10+
BallCover,
11+
CubicalCover,
12+
KNNCover,
13+
ProximityCubicalCover,
14+
StandardCubicalCover,
15+
)
16+
from tdamapper.utils.unionfind import UnionFind
617

718

819
def dataset_simple():
920
"""
1021
Create a simple dataset of points in a 2D space.
22+
23+
This dataset consists of four points forming the corners of a rectangle
24+
such that two sides are longer than the other two.
1125
"""
1226
return [
1327
np.array([0.0, 1.0]),
@@ -46,6 +60,13 @@ def dataset_grid(num=1000):
4660
return grid
4761

4862

63+
SIMPLE = dataset_simple()
64+
65+
TWO_LINES = dataset_two_lines()
66+
67+
GRID = dataset_grid(10)
68+
69+
4970
def assert_coverage(data, cover):
5071
"""
5172
Assert that the cover applies to the data and covers all points.
@@ -73,170 +94,117 @@ def count_components(charts):
7394
point_charts[point_id] = []
7495
point_charts[point_id].append(chart_id)
7596

76-
chart_components = {x: x for x in range(len(charts))}
97+
uf = UnionFind(list(range(len(charts))))
7798
for point_id, chart_ids in point_charts.items():
78-
if len(chart_ids) > 1:
79-
# Union all chart ids for this point
80-
first_chart = chart_ids[0]
81-
for chart_id in chart_ids[1:]:
82-
chart_components[chart_id] = chart_components[first_chart]
83-
# Count unique components
84-
unique_components = set(chart_components.values())
99+
for i in range(len(chart_ids) - 1):
100+
uf.union(chart_ids[i], chart_ids[i + 1])
101+
# Count the number of unique components
102+
unique_components = set()
103+
for chart_id in range(len(charts)):
104+
unique_components.add(uf.find(chart_id))
85105
return len(unique_components)
86106

87107

88-
def test_trivial_cover_random():
89-
data = dataset_random()
90-
cover = TrivialCover()
91-
assert_coverage(data, cover)
92-
93-
94-
def test_trivial_cover_two_lines():
95-
data = dataset_two_lines()
96-
cover = TrivialCover()
97-
charts = assert_coverage(data, cover)
98-
assert 1 == len(charts)
99-
num_components = count_components(charts)
100-
assert 1 == num_components
101-
102-
103108
@pytest.mark.parametrize(
104109
"dataset, cover, num_charts, num_components",
105110
[
106-
# Simple dataset tests
107-
(dataset_simple(), TrivialCover(), 1, 1),
108-
(dataset_simple(), BallCover(radius=1.1, metric="euclidean"), 2, 2),
109-
(dataset_simple(), KNNCover(neighbors=2, metric="euclidean"), 2, 2),
110-
(dataset_simple(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None),
111-
# Two lines dataset tests
112-
(dataset_two_lines(), TrivialCover(), 1, 1),
113-
(dataset_two_lines(), BallCover(radius=0.2, metric="euclidean"), None, 2),
114-
(dataset_two_lines(), KNNCover(neighbors=10, metric="euclidean"), None, 2),
115-
(dataset_two_lines(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None),
116-
# Grid dataset tests
117-
(dataset_grid(), TrivialCover(), 1, 1),
118-
(dataset_grid(), BallCover(radius=0.05, metric="euclidean"), None, 1),
119-
(dataset_grid(), KNNCover(neighbors=10, metric="euclidean"), None, 1),
120-
(dataset_grid(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None),
111+
# Simple dataset
112+
(SIMPLE, TrivialCover(), 1, 1),
113+
# BallCover: components are expected to merge when the radius crosses the
114+
# lengths of the rectangle sides.
115+
(SIMPLE, BallCover(radius=0.9, metric="euclidean"), 4, 4),
116+
(SIMPLE, BallCover(radius=1.1, metric="euclidean"), 2, 2),
117+
(SIMPLE, BallCover(radius=1.5, metric="euclidean"), 1, 1),
118+
# KNNCover: components are expected to merge when the number of neighbors
119+
# is enough to cover a given number of rectangle sides.
120+
(SIMPLE, KNNCover(neighbors=1, metric="euclidean"), 4, 4),
121+
(SIMPLE, KNNCover(neighbors=2, metric="euclidean"), 2, 2),
122+
(SIMPLE, KNNCover(neighbors=3, metric="euclidean"), 2, 1),
123+
# StandardCubicalCover: components are expected to merge when intervals
124+
# are big enough to cover the rectangle sides.
125+
(SIMPLE, StandardCubicalCover(n_intervals=2, overlap_frac=0.1), 4, 4),
126+
(SIMPLE, StandardCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 4),
127+
(SIMPLE, StandardCubicalCover(n_intervals=1, overlap_frac=0.5), 1, 1),
128+
(SIMPLE, ProximityCubicalCover(n_intervals=2, overlap_frac=0.1), 4, 4),
129+
(SIMPLE, ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 4),
130+
(SIMPLE, ProximityCubicalCover(n_intervals=1, overlap_frac=0.5), 1, 1),
131+
# Two lines dataset
132+
(TWO_LINES, TrivialCover(), 1, 1),
133+
# BallCover: components are expected to merge when the radius crosses the
134+
# distance between the two lines.
135+
(TWO_LINES, BallCover(radius=0.2, metric="euclidean"), 10, 2),
136+
(TWO_LINES, BallCover(radius=0.5, metric="euclidean"), 4, 2),
137+
(TWO_LINES, BallCover(radius=1.0, metric="euclidean"), 4, 2),
138+
(TWO_LINES, BallCover(radius=1.1, metric="euclidean"), 2, 1),
139+
(TWO_LINES, BallCover(radius=1.5, metric="euclidean"), 1, 1),
140+
# KNNCover: components are expected to merge when the number of neighbors
141+
# is more than the cardinality of a single line.
142+
(TWO_LINES, KNNCover(neighbors=3, metric="euclidean"), None, 2),
143+
(TWO_LINES, KNNCover(neighbors=10, metric="euclidean"), None, 2),
144+
(TWO_LINES, KNNCover(neighbors=100, metric="euclidean"), None, 2),
145+
(TWO_LINES, KNNCover(neighbors=1001, metric="euclidean"), 2, 1),
146+
# StandardCubicalCover: components are expected to merge when intervals
147+
# are big enough to cover the distance between the two lines.
148+
(TWO_LINES, StandardCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 2),
149+
(TWO_LINES, ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 2),
150+
# Grid dataset
151+
(GRID, TrivialCover(), 1, 1),
152+
# BallCover: components are expected to jump from many singleton sets
153+
# to a single one when the radius crosses the grid spacing.
154+
(GRID, BallCover(radius=0.01, metric="euclidean"), 100, 100),
155+
(GRID, BallCover(radius=0.2, metric="euclidean"), None, 1),
156+
# KNNCover: components are expected to merge when the number of neighbors
157+
# is more than the number of adjacent points in the grid.
158+
(GRID, KNNCover(neighbors=1, metric="euclidean"), 100, 100),
159+
(GRID, KNNCover(neighbors=10, metric="euclidean"), None, 1),
160+
(GRID, StandardCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 1),
161+
(GRID, ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 1),
121162
],
122163
)
123164
def test_cover(dataset, cover, num_charts, num_components):
165+
"""
166+
Test that the cover algorithm covers the dataset correctly, and that the
167+
number of charts and components is as expected. If num_charts or
168+
num_components is None, the test will not check that value.
169+
"""
124170
charts = assert_coverage(dataset, cover)
125171
if num_charts is not None:
126172
assert len(charts) == num_charts
127173
if num_components is not None:
128174
assert count_components(charts) == num_components
129175

130176

131-
def test_trivial_cover_grid():
132-
data = dataset_two_lines()
133-
cover = TrivialCover()
134-
charts = assert_coverage(data, cover)
135-
assert 1 == len(charts)
136-
num_components = count_components(charts)
137-
assert 1 == num_components
138-
139-
140-
def test_ball_cover_simple():
141-
data = [
142-
np.array([0.0, 1.0]),
143-
np.array([1.0, 0.0]),
144-
np.array([0.0, 0.0]),
145-
np.array([1.0, 1.0]),
146-
]
147-
cover = BallCover(radius=1.1, metric="euclidean")
148-
charts = assert_coverage(data, cover)
149-
assert 2 == len(charts)
150-
num_components = count_components(charts)
151-
assert 1 == num_components
152-
153-
154-
def test_ball_cover_random():
155-
data = dataset_random(dim=2, num=10)
156-
cover = BallCover(radius=0.2, metric="euclidean")
157-
assert_coverage(data, cover)
158-
159-
160-
def test_ball_cover_two_lines():
161-
data = dataset_two_lines()
162-
cover = BallCover(radius=0.2, metric="euclidean")
163-
charts = assert_coverage(data, cover)
164-
num_components = count_components(charts)
165-
assert 2 == num_components
166-
167-
168-
def test_ball_cover_grid():
169-
data = dataset_grid(num=100)
170-
cover = BallCover(radius=0.05, metric="euclidean")
171-
charts = assert_coverage(data, cover)
172-
num_components = count_components(charts)
173-
assert 1 == num_components
174-
175-
176-
def test_knn_cover_simple():
177-
data = [
178-
np.array([0.0, 1.0]),
179-
np.array([1.1, 0.0]),
180-
np.array([0.0, 0.0]),
181-
np.array([1.1, 1.0]),
182-
]
183-
cover = KNNCover(neighbors=2, metric="euclidean")
184-
charts = assert_coverage(data, cover)
185-
assert 2 == len(charts)
186-
187-
188-
def test_knn_cover_two_lines():
189-
data = dataset_two_lines()
190-
cover = KNNCover(neighbors=10, metric="euclidean")
191-
charts = assert_coverage(data, cover)
192-
num_components = count_components(charts)
193-
assert 2 == num_components
194-
195-
196-
def test_knn_cover_grid():
197-
data = dataset_grid(num=100)
198-
cover = KNNCover(neighbors=10, metric="euclidean")
199-
charts = assert_coverage(data, cover)
200-
num_components = count_components(charts)
201-
assert 1 == num_components
202-
203-
204-
def test_cubical_cover_simple():
205-
data = [
206-
np.array([0.0, 1.0]),
207-
np.array([1.1, 0.0]),
208-
np.array([0.0, 0.0]),
209-
np.array([1.1, 1.0]),
210-
]
211-
cover = CubicalCover(n_intervals=2, overlap_frac=0.5)
212-
charts = list(cover.apply(data))
213-
assert 4 == len(charts)
214-
215-
216-
def test_cubical_cover_random():
217-
data = dataset_random(dim=2, num=100)
218-
cover = CubicalCover(n_intervals=5, overlap_frac=0.1)
219-
assert_coverage(data, cover)
220-
221-
222-
def test_cubical_cover_params():
223-
cover = CubicalCover(n_intervals=2, overlap_frac=0.5)
177+
@pytest.mark.parametrize(
178+
"cover, params",
179+
[
180+
(TrivialCover(), {}),
181+
(
182+
BallCover(radius=0.2, metric="euclidean"),
183+
{"radius": 0.21, "metric": "euclidean"},
184+
),
185+
(
186+
KNNCover(neighbors=10, metric="euclidean"),
187+
{"neighbors": 13, "metric": "euclidean"},
188+
),
189+
(
190+
StandardCubicalCover(n_intervals=2, overlap_frac=0.5),
191+
{"n_intervals": 4, "overlap_frac": 0.145},
192+
),
193+
(
194+
ProximityCubicalCover(n_intervals=2, overlap_frac=0.5),
195+
{"n_intervals": 4, "overlap_frac": 0.145},
196+
),
197+
(
198+
CubicalCover(n_intervals=2, overlap_frac=0.5, algorithm="standard"),
199+
{"n_intervals": 4, "overlap_frac": 0.145, "algorithm": "proximity"},
200+
),
201+
],
202+
)
203+
def test_params(cover, params):
204+
"""
205+
Test that the cover can get and set parameters correctly.
206+
"""
207+
cover.set_params(**params)
224208
params = cover.get_params(deep=True)
225-
assert 2 == params["n_intervals"]
226-
assert 0.5 == params["overlap_frac"]
227-
228-
229-
def test_standard_cover_simple():
230-
data = [
231-
np.array([0.0, 1.0]),
232-
np.array([1.1, 0.0]),
233-
np.array([0.0, 0.0]),
234-
np.array([1.1, 1.0]),
235-
]
236-
cover = CubicalCover(
237-
n_intervals=2,
238-
overlap_frac=0.5,
239-
algorithm="standard",
240-
)
241-
charts = list(cover.apply(data))
242-
assert 4 == len(charts)
209+
for k, v in params.items():
210+
assert params[k] == v

0 commit comments

Comments
 (0)