|
| 1 | +""" |
| 2 | +Unit tests for the cover algorithms. |
| 3 | +""" |
| 4 | + |
1 | 5 | import numpy as np |
2 | 6 | import pytest |
3 | 7 |
|
4 | 8 | from tdamapper.core import TrivialCover |
5 | | -from tdamapper.cover import BallCover, CubicalCover, KNNCover |
| 9 | +from tdamapper.cover import ( |
| 10 | + BallCover, |
| 11 | + CubicalCover, |
| 12 | + KNNCover, |
| 13 | + ProximityCubicalCover, |
| 14 | + StandardCubicalCover, |
| 15 | +) |
| 16 | +from tdamapper.utils.unionfind import UnionFind |
6 | 17 |
|
7 | 18 |
|
8 | 19 | def dataset_simple(): |
9 | 20 | """ |
10 | 21 | Create a simple dataset of points in a 2D space. |
| 22 | +
|
| 23 | + This dataset consists of four points forming the corners of a rectangle |
| 24 | + such that two sides are longer than the other two. |
11 | 25 | """ |
12 | 26 | return [ |
13 | 27 | np.array([0.0, 1.0]), |
@@ -46,6 +60,13 @@ def dataset_grid(num=1000): |
46 | 60 | return grid |
47 | 61 |
|
48 | 62 |
|
| 63 | +SIMPLE = dataset_simple() |
| 64 | + |
| 65 | +TWO_LINES = dataset_two_lines() |
| 66 | + |
| 67 | +GRID = dataset_grid(10) |
| 68 | + |
| 69 | + |
49 | 70 | def assert_coverage(data, cover): |
50 | 71 | """ |
51 | 72 | Assert that the cover applies to the data and covers all points. |
@@ -73,170 +94,117 @@ def count_components(charts): |
73 | 94 | point_charts[point_id] = [] |
74 | 95 | point_charts[point_id].append(chart_id) |
75 | 96 |
|
76 | | - chart_components = {x: x for x in range(len(charts))} |
| 97 | + uf = UnionFind(list(range(len(charts)))) |
77 | 98 | for point_id, chart_ids in point_charts.items(): |
78 | | - if len(chart_ids) > 1: |
79 | | - # Union all chart ids for this point |
80 | | - first_chart = chart_ids[0] |
81 | | - for chart_id in chart_ids[1:]: |
82 | | - chart_components[chart_id] = chart_components[first_chart] |
83 | | - # Count unique components |
84 | | - unique_components = set(chart_components.values()) |
| 99 | + for i in range(len(chart_ids) - 1): |
| 100 | + uf.union(chart_ids[i], chart_ids[i + 1]) |
| 101 | + # Count the number of unique components |
| 102 | + unique_components = set() |
| 103 | + for chart_id in range(len(charts)): |
| 104 | + unique_components.add(uf.find(chart_id)) |
85 | 105 | return len(unique_components) |
86 | 106 |
|
87 | 107 |
|
88 | | -def test_trivial_cover_random(): |
89 | | - data = dataset_random() |
90 | | - cover = TrivialCover() |
91 | | - assert_coverage(data, cover) |
92 | | - |
93 | | - |
94 | | -def test_trivial_cover_two_lines(): |
95 | | - data = dataset_two_lines() |
96 | | - cover = TrivialCover() |
97 | | - charts = assert_coverage(data, cover) |
98 | | - assert 1 == len(charts) |
99 | | - num_components = count_components(charts) |
100 | | - assert 1 == num_components |
101 | | - |
102 | | - |
103 | 108 | @pytest.mark.parametrize( |
104 | 109 | "dataset, cover, num_charts, num_components", |
105 | 110 | [ |
106 | | - # Simple dataset tests |
107 | | - (dataset_simple(), TrivialCover(), 1, 1), |
108 | | - (dataset_simple(), BallCover(radius=1.1, metric="euclidean"), 2, 2), |
109 | | - (dataset_simple(), KNNCover(neighbors=2, metric="euclidean"), 2, 2), |
110 | | - (dataset_simple(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None), |
111 | | - # Two lines dataset tests |
112 | | - (dataset_two_lines(), TrivialCover(), 1, 1), |
113 | | - (dataset_two_lines(), BallCover(radius=0.2, metric="euclidean"), None, 2), |
114 | | - (dataset_two_lines(), KNNCover(neighbors=10, metric="euclidean"), None, 2), |
115 | | - (dataset_two_lines(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None), |
116 | | - # Grid dataset tests |
117 | | - (dataset_grid(), TrivialCover(), 1, 1), |
118 | | - (dataset_grid(), BallCover(radius=0.05, metric="euclidean"), None, 1), |
119 | | - (dataset_grid(), KNNCover(neighbors=10, metric="euclidean"), None, 1), |
120 | | - (dataset_grid(), CubicalCover(n_intervals=2, overlap_frac=0.5), 4, None), |
| 111 | + # Simple dataset |
| 112 | + (SIMPLE, TrivialCover(), 1, 1), |
| 113 | + # BallCover: components are expected to merge when the radius crosses the |
| 114 | + # lenghts of the rectangle sides. |
| 115 | + (SIMPLE, BallCover(radius=0.9, metric="euclidean"), 4, 4), |
| 116 | + (SIMPLE, BallCover(radius=1.1, metric="euclidean"), 2, 2), |
| 117 | + (SIMPLE, BallCover(radius=1.5, metric="euclidean"), 1, 1), |
| 118 | + # KNNCover: components are expected to merge when the number of neighbors |
| 119 | + # is enough to cover a given number of rectangle sides. |
| 120 | + (SIMPLE, KNNCover(neighbors=1, metric="euclidean"), 4, 4), |
| 121 | + (SIMPLE, KNNCover(neighbors=2, metric="euclidean"), 2, 2), |
| 122 | + (SIMPLE, KNNCover(neighbors=3, metric="euclidean"), 2, 1), |
| 123 | + # StandardCubicalCover: components are expected to merge when intervals |
| 124 | + # are big enough to cover the rectangle sides. |
| 125 | + (SIMPLE, StandardCubicalCover(n_intervals=2, overlap_frac=0.1), 4, 4), |
| 126 | + (SIMPLE, StandardCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 4), |
| 127 | + (SIMPLE, StandardCubicalCover(n_intervals=1, overlap_frac=0.5), 1, 1), |
| 128 | + (SIMPLE, ProximityCubicalCover(n_intervals=2, overlap_frac=0.1), 4, 4), |
| 129 | + (SIMPLE, ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 4), |
| 130 | + (SIMPLE, ProximityCubicalCover(n_intervals=1, overlap_frac=0.5), 1, 1), |
| 131 | + # Two lines dataset |
| 132 | + (TWO_LINES, TrivialCover(), 1, 1), |
| 133 | + # BallCover: components are expected to merge when the radius crosses the |
| 134 | + # distance between the two lines. |
| 135 | + (TWO_LINES, BallCover(radius=0.2, metric="euclidean"), 10, 2), |
| 136 | + (TWO_LINES, BallCover(radius=0.5, metric="euclidean"), 4, 2), |
| 137 | + (TWO_LINES, BallCover(radius=1.0, metric="euclidean"), 4, 2), |
| 138 | + (TWO_LINES, BallCover(radius=1.1, metric="euclidean"), 2, 1), |
| 139 | + (TWO_LINES, BallCover(radius=1.5, metric="euclidean"), 1, 1), |
| 140 | + # KNNCover: components are expected to merge when the number of neighbors |
| 141 | + # is more than the cardinality of a single line. |
| 142 | + (TWO_LINES, KNNCover(neighbors=3, metric="euclidean"), None, 2), |
| 143 | + (TWO_LINES, KNNCover(neighbors=10, metric="euclidean"), None, 2), |
| 144 | + (TWO_LINES, KNNCover(neighbors=100, metric="euclidean"), None, 2), |
| 145 | + (TWO_LINES, KNNCover(neighbors=1001, metric="euclidean"), 2, 1), |
| 146 | + # StandardCubicalCover: components are expected to merge when intervals |
| 147 | + # are big enough to cover the distance between the two lines. |
| 148 | + (TWO_LINES, StandardCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 2), |
| 149 | + (TWO_LINES, ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 2), |
| 150 | + # Grid dataset |
| 151 | + (GRID, TrivialCover(), 1, 1), |
| 152 | + # BallCover: components are expected to jump from many singletons sets |
| 153 | + # to a single one when the radius crosses the grid spacing. |
| 154 | + (GRID, BallCover(radius=0.01, metric="euclidean"), 100, 100), |
| 155 | + (GRID, BallCover(radius=0.2, metric="euclidean"), None, 1), |
| 156 | + # KNNCover: components are expected to merge when the number of neighbors |
| 157 | + # is more than the number of adjacent points in the grid. |
| 158 | + (GRID, KNNCover(neighbors=1, metric="euclidean"), 100, 100), |
| 159 | + (GRID, KNNCover(neighbors=10, metric="euclidean"), None, 1), |
| 160 | + (GRID, StandardCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 1), |
| 161 | + (GRID, ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), 4, 1), |
121 | 162 | ], |
122 | 163 | ) |
123 | 164 | def test_cover(dataset, cover, num_charts, num_components): |
| 165 | + """ |
| 166 | + Test that the cover algorithm covers the dataset correctly, and that the |
| 167 | + number of charts and components is as expected. If num_charts or |
| 168 | + num_components is None, the test will not check that value. |
| 169 | + """ |
124 | 170 | charts = assert_coverage(dataset, cover) |
125 | 171 | if num_charts is not None: |
126 | 172 | assert len(charts) == num_charts |
127 | 173 | if num_components is not None: |
128 | 174 | assert count_components(charts) == num_components |
129 | 175 |
|
130 | 176 |
|
131 | | -def test_trivial_cover_grid(): |
132 | | - data = dataset_two_lines() |
133 | | - cover = TrivialCover() |
134 | | - charts = assert_coverage(data, cover) |
135 | | - assert 1 == len(charts) |
136 | | - num_components = count_components(charts) |
137 | | - assert 1 == num_components |
138 | | - |
139 | | - |
140 | | -def test_ball_cover_simple(): |
141 | | - data = [ |
142 | | - np.array([0.0, 1.0]), |
143 | | - np.array([1.0, 0.0]), |
144 | | - np.array([0.0, 0.0]), |
145 | | - np.array([1.0, 1.0]), |
146 | | - ] |
147 | | - cover = BallCover(radius=1.1, metric="euclidean") |
148 | | - charts = assert_coverage(data, cover) |
149 | | - assert 2 == len(charts) |
150 | | - num_components = count_components(charts) |
151 | | - assert 1 == num_components |
152 | | - |
153 | | - |
154 | | -def test_ball_cover_random(): |
155 | | - data = dataset_random(dim=2, num=10) |
156 | | - cover = BallCover(radius=0.2, metric="euclidean") |
157 | | - assert_coverage(data, cover) |
158 | | - |
159 | | - |
160 | | -def test_ball_cover_two_lines(): |
161 | | - data = dataset_two_lines() |
162 | | - cover = BallCover(radius=0.2, metric="euclidean") |
163 | | - charts = assert_coverage(data, cover) |
164 | | - num_components = count_components(charts) |
165 | | - assert 2 == num_components |
166 | | - |
167 | | - |
168 | | -def test_ball_cover_grid(): |
169 | | - data = dataset_grid(num=100) |
170 | | - cover = BallCover(radius=0.05, metric="euclidean") |
171 | | - charts = assert_coverage(data, cover) |
172 | | - num_components = count_components(charts) |
173 | | - assert 1 == num_components |
174 | | - |
175 | | - |
176 | | -def test_knn_cover_simple(): |
177 | | - data = [ |
178 | | - np.array([0.0, 1.0]), |
179 | | - np.array([1.1, 0.0]), |
180 | | - np.array([0.0, 0.0]), |
181 | | - np.array([1.1, 1.0]), |
182 | | - ] |
183 | | - cover = KNNCover(neighbors=2, metric="euclidean") |
184 | | - charts = assert_coverage(data, cover) |
185 | | - assert 2 == len(charts) |
186 | | - |
187 | | - |
188 | | -def test_knn_cover_two_lines(): |
189 | | - data = dataset_two_lines() |
190 | | - cover = KNNCover(neighbors=10, metric="euclidean") |
191 | | - charts = assert_coverage(data, cover) |
192 | | - num_components = count_components(charts) |
193 | | - assert 2 == num_components |
194 | | - |
195 | | - |
196 | | -def test_knn_cover_grid(): |
197 | | - data = dataset_grid(num=100) |
198 | | - cover = KNNCover(neighbors=10, metric="euclidean") |
199 | | - charts = assert_coverage(data, cover) |
200 | | - num_components = count_components(charts) |
201 | | - assert 1 == num_components |
202 | | - |
203 | | - |
204 | | -def test_cubical_cover_simple(): |
205 | | - data = [ |
206 | | - np.array([0.0, 1.0]), |
207 | | - np.array([1.1, 0.0]), |
208 | | - np.array([0.0, 0.0]), |
209 | | - np.array([1.1, 1.0]), |
210 | | - ] |
211 | | - cover = CubicalCover(n_intervals=2, overlap_frac=0.5) |
212 | | - charts = list(cover.apply(data)) |
213 | | - assert 4 == len(charts) |
214 | | - |
215 | | - |
216 | | -def test_cubical_cover_random(): |
217 | | - data = dataset_random(dim=2, num=100) |
218 | | - cover = CubicalCover(n_intervals=5, overlap_frac=0.1) |
219 | | - assert_coverage(data, cover) |
220 | | - |
221 | | - |
222 | | -def test_cubical_cover_params(): |
223 | | - cover = CubicalCover(n_intervals=2, overlap_frac=0.5) |
| 177 | +@pytest.mark.parametrize( |
| 178 | + "cover, params", |
| 179 | + [ |
| 180 | + (TrivialCover(), {}), |
| 181 | + ( |
| 182 | + BallCover(radius=0.2, metric="euclidean"), |
| 183 | + {"radius": 0.21, "metric": "euclidean"}, |
| 184 | + ), |
| 185 | + ( |
| 186 | + KNNCover(neighbors=10, metric="euclidean"), |
| 187 | + {"neighbors": 13, "metric": "euclidean"}, |
| 188 | + ), |
| 189 | + ( |
| 190 | + StandardCubicalCover(n_intervals=2, overlap_frac=0.5), |
| 191 | + {"n_intervals": 4, "overlap_frac": 0.145}, |
| 192 | + ), |
| 193 | + ( |
| 194 | + ProximityCubicalCover(n_intervals=2, overlap_frac=0.5), |
| 195 | + {"n_intervals": 4, "overlap_frac": 0.145}, |
| 196 | + ), |
| 197 | + ( |
| 198 | + CubicalCover(n_intervals=2, overlap_frac=0.5, algorithm="standard"), |
| 199 | + {"n_intervals": 4, "overlap_frac": 0.145, "algorithm": "proximity"}, |
| 200 | + ), |
| 201 | + ], |
| 202 | +) |
| 203 | +def test_params(cover, params): |
| 204 | + """ |
| 205 | + Test that the cover can get and set parameters correctly. |
| 206 | + """ |
| 207 | + cover.set_params(**params) |
224 | 208 | params = cover.get_params(deep=True) |
225 | | - assert 2 == params["n_intervals"] |
226 | | - assert 0.5 == params["overlap_frac"] |
227 | | - |
228 | | - |
229 | | -def test_standard_cover_simple(): |
230 | | - data = [ |
231 | | - np.array([0.0, 1.0]), |
232 | | - np.array([1.1, 0.0]), |
233 | | - np.array([0.0, 0.0]), |
234 | | - np.array([1.1, 1.0]), |
235 | | - ] |
236 | | - cover = CubicalCover( |
237 | | - n_intervals=2, |
238 | | - overlap_frac=0.5, |
239 | | - algorithm="standard", |
240 | | - ) |
241 | | - charts = list(cover.apply(data)) |
242 | | - assert 4 == len(charts) |
| 209 | + for k, v in params.items(): |
| 210 | + assert params[k] == v |
0 commit comments