Skip to content

Commit 25f7689

Browse files
authored
Merge pull request #249 from lucasimi/improve-testing-cover
Improve testing cover
2 parents b921c63 + edb279b commit 25f7689

21 files changed

+859
-730
lines changed

src/tdamapper/_common.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def deprecated(msg: str) -> Callable[..., Any]:
2727
"""
2828

2929
def deprecated_func(func: Callable[..., Any]) -> Callable[..., Any]:
30-
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
30+
def wrapper(*args: list[Any], **kwargs: Any) -> Any:
3131
warnings.warn(msg, DeprecationWarning, stacklevel=2)
3232
return func(*args, **kwargs)
3333

@@ -179,10 +179,12 @@ def __repr__(self) -> str:
179179
obj_noargs = type(self)()
180180
args_repr = []
181181
for k, v in self.__dict__.items():
182+
if not self._is_param_public(k):
183+
continue
182184
v_default = getattr(obj_noargs, k)
183185
v_default_repr = repr(v_default)
184186
v_repr = repr(v)
185-
if self._is_param_public(k) and not v_repr == v_default_repr:
187+
if not v_repr == v_default_repr:
186188
args_repr.append(f"{k}={v_repr}")
187189
return f"{self.__class__.__name__}({', '.join(args_repr)})"
188190

@@ -211,7 +213,7 @@ def profile(n_lines: int = 10) -> Callable[..., Any]:
211213
"""
212214

213215
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
214-
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
216+
def wrapper(*args: list[Any], **kwargs: Any) -> Any:
215217
profiler = cProfile.Profile()
216218
profiler.enable()
217219
result = func(*args, **kwargs)

src/tdamapper/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def _umap(X: NDArray[np.float_]) -> NDArray[np.float_]:
202202

203203

204204
def run_mapper(
205-
df: pd.DataFrame, **kwargs: dict[str, Any]
205+
df: pd.DataFrame, **kwargs: Any
206206
) -> Optional[tuple[nx.Graph, pd.DataFrame]]:
207207
"""
208208
Runs the Mapper algorithm on the provided DataFrame and returns the Mapper
@@ -301,7 +301,7 @@ def create_mapper_figure(
301301
df_y: pd.DataFrame,
302302
df_target: pd.DataFrame,
303303
mapper_graph: nx.Graph,
304-
**kwargs: dict[str, Any],
304+
**kwargs: Any,
305305
) -> go.Figure:
306306
"""
307307
Renders the Mapper graph as a Plotly figure.

src/tdamapper/core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,24 @@ class TrivialCover(ParamsMixin, Generic[T]):
279279
dataset.
280280
"""
281281

282+
def fit(self, X: ArrayRead[T]) -> TrivialCover[T]:
283+
"""
284+
Fit the cover algorithm to the data.
285+
286+
:param X: A dataset of n points. Ignored.
287+
:return: self
288+
"""
289+
return self
290+
282291
def apply(self, X: ArrayRead[T]) -> Iterator[list[int]]:
283292
"""
284293
Covers the dataset with a single open set.
285294
286295
:param X: A dataset of n points.
287296
:return: A generator of lists of ids.
288297
"""
289-
yield list(range(0, len(X)))
298+
if len(X) > 0:
299+
yield list(range(0, len(X)))
290300

291301

292302
class FailSafeClustering(ParamsMixin, Generic[T]):

src/tdamapper/cover.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,8 @@ def fit(self, X: ArrayRead[NDArray[np.float_]]) -> BaseCubicalCover:
377377
:param X: A dataset of n points.
378378
:return: The object itself.
379379
"""
380+
if len(X) == 0:
381+
return self
380382
X_ = np.asarray(X).reshape(len(X), -1).astype(float)
381383
if self.overlap_frac is None:
382384
dim = 1 if X_.ndim == 1 else X_.shape[1]

src/tdamapper/learn.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,19 @@ def fit(
159159
"""
160160
y_ = X if y is None else y
161161
X, y_ = self._validate_X_y(X, y_)
162-
self._cover = TrivialCover() if self.cover is None else self.cover
163-
self._clustering = (
164-
TrivialClustering() if self.clustering is None else self.clustering
165-
)
162+
self._cover = TrivialCover()
163+
if self.cover is not None:
164+
self._cover = clone(self.cover)
165+
self._clustering = TrivialClustering()
166+
if self.clustering is not None:
167+
self._clustering = clone(self.clustering)
166168
self._verbose = self.verbose
167169
self._failsafe = self.failsafe
168170
if self._failsafe:
169171
self._clustering = FailSafeClustering(
170172
clustering=self._clustering,
171173
verbose=self._verbose,
172174
)
173-
self._cover = clone(self._cover)
174-
self._clustering = clone(self._clustering)
175175
self._n_jobs = self.n_jobs
176176
self.graph_ = mapper_graph(
177177
X,

src/tdamapper/utils/heap.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
"""
2+
This module implements a max-heap data structure that allows for efficient
3+
retrieval and removal of the maximum element. The heap supports adding
4+
elements, retrieving the maximum element, and removing the maximum element
5+
while maintaining the heap property.
6+
"""
7+
18
from __future__ import annotations
29

310
from typing import Generic, Iterator, Optional, Protocol, TypeVar
@@ -16,6 +23,9 @@ def _parent(i: int) -> int:
1623

1724

1825
class Comparable(Protocol):
26+
"""
27+
Protocol for comparison methods required for a key in the heap.
28+
"""
1929

2030
def __lt__(self: K, other: K) -> bool: ...
2131

@@ -32,6 +42,14 @@ def __ge__(self: K, other: K) -> bool: ...
3242

3343

3444
class _HeapNode(Generic[K, V]):
45+
"""
46+
A node in the heap that holds a key-value pair.
47+
48+
The key is used for comparison, and the value is stored alongside it.
49+
50+
:param key: The key used for comparison.
51+
:param value: The value associated with the key.
52+
"""
3553

3654
_key: K
3755
_value: V
@@ -41,6 +59,11 @@ def __init__(self, key: K, value: V) -> None:
4159
self._value = value
4260

4361
def get(self) -> tuple[K, V]:
62+
"""
63+
Returns the key-value pair stored in the node.
64+
65+
:return: A tuple containing the key and value.
66+
"""
4467
return self._key, self._value
4568

4669
def __lt__(self, other: _HeapNode[K, V]) -> bool:
@@ -57,6 +80,12 @@ def __ge__(self, other: _HeapNode[K, V]) -> bool:
5780

5881

5982
class MaxHeap(Generic[K, V]):
83+
"""
84+
A max-heap implementation that allows for efficient retrieval of the
85+
maximum element. This heap supports adding elements, retrieving the maximum
86+
element, and removing the maximum element while maintaining the heap
87+
property.
88+
"""
6089

6190
_heap: list[_HeapNode[K, V]]
6291
_iter: Iterator[_HeapNode[K, V]]
@@ -75,12 +104,32 @@ def __next__(self) -> tuple[K, V]:
75104
def __len__(self) -> int:
76105
return len(self._heap)
77106

107+
def is_empty(self) -> bool:
108+
"""
109+
Check if the heap is empty.
110+
111+
:return: True if the heap is empty, False otherwise.
112+
"""
113+
return len(self._heap) == 0
114+
78115
def top(self) -> Optional[tuple[K, V]]:
116+
"""
117+
Returns the maximum element in the heap without removing it.
118+
119+
:return: A tuple containing the key and value of the maximum element,
120+
or None if the heap is empty.
121+
"""
79122
if not self._heap:
80123
return None
81124
return self._heap[0].get()
82125

83126
def pop(self) -> Optional[tuple[K, V]]:
127+
"""
128+
Removes and returns the maximum element from the heap.
129+
130+
:return: A tuple containing the key and value of the maximum element,
131+
or None if the heap is empty.
132+
"""
84133
if not self._heap:
85134
return None
86135
max_val = self._heap[0]
@@ -89,8 +138,14 @@ def pop(self) -> Optional[tuple[K, V]]:
89138
self._bubble_down()
90139
return max_val.get()
91140

92-
def add(self, key: K, val: V) -> None:
93-
self._heap.append(_HeapNode(key, val))
141+
def add(self, key: K, value: V) -> None:
142+
"""
143+
Adds a new key-value pair to the heap.
144+
145+
:param key: The key used for comparison.
146+
:param value: The value associated with the key.
147+
"""
148+
self._heap.append(_HeapNode(key, value))
94149
self._bubble_up()
95150

96151
def _get_local_max(self, i: int) -> int:

src/tdamapper/utils/metrics.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_supported_metrics() -> list[MetricLiteral]:
5858
return list(get_args(MetricLiteral))
5959

6060

61-
def euclidean(**kwargs: dict[str, Any]) -> Metric[Any]:
61+
def euclidean() -> Metric[Any]:
6262
"""
6363
Return the Euclidean distance function for vectors.
6464
@@ -70,7 +70,7 @@ def euclidean(**kwargs: dict[str, Any]) -> Metric[Any]:
7070
return _metrics.euclidean
7171

7272

73-
def manhattan(**kwargs: dict[str, Any]) -> Metric[Any]:
73+
def manhattan() -> Metric[Any]:
7474
"""
7575
Return the Manhattan distance function for vectors.
7676
@@ -82,7 +82,7 @@ def manhattan(**kwargs: dict[str, Any]) -> Metric[Any]:
8282
return _metrics.manhattan
8383

8484

85-
def chebyshev(**kwargs: dict[str, Any]) -> Metric[Any]:
85+
def chebyshev() -> Metric[Any]:
8686
"""
8787
Return the Chebyshev distance function for vectors.
8888
@@ -94,7 +94,7 @@ def chebyshev(**kwargs: dict[str, Any]) -> Metric[Any]:
9494
return _metrics.chebyshev
9595

9696

97-
def minkowski(**kwargs: dict[str, Any]) -> Metric[Any]:
97+
def minkowski(p: Union[int, float]) -> Metric[Any]:
9898
"""
9999
Return the Minkowski distance function for order p on vectors.
100100
@@ -106,9 +106,6 @@ def minkowski(**kwargs: dict[str, Any]) -> Metric[Any]:
106106
:param p: The order of the Minkowski distance.
107107
:return: The Minkowski distance function.
108108
"""
109-
p = kwargs.get("p", 2)
110-
if not isinstance(p, (int, float)):
111-
raise TypeError("p must be an integer or a float")
112109
if p == 1:
113110
return manhattan()
114111
if p == 2:
@@ -122,7 +119,7 @@ def dist(x: Any, y: Any) -> float:
122119
return dist
123120

124121

125-
def cosine(**kwargs: dict[str, Any]) -> Metric[Any]:
122+
def cosine() -> Metric[Any]:
126123
"""
127124
Return the cosine distance function for vectors.
128125
@@ -145,9 +142,7 @@ def cosine(**kwargs: dict[str, Any]) -> Metric[Any]:
145142
return _metrics.cosine
146143

147144

148-
def get_metric(
149-
metric: Union[MetricLiteral, Metric[Any]], **kwargs: dict[str, Any]
150-
) -> Metric[Any]:
145+
def get_metric(metric: Union[MetricLiteral, Metric[Any]], **kwargs: Any) -> Metric[Any]:
151146
"""
152147
Return a distance function based on the specified string or callable.
153148

src/tdamapper/utils/unionfind.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,36 @@
1+
"""
2+
This module implements a Union-Find data structure that supports union and
3+
find operations.
4+
"""
5+
16
from typing import Any, Iterable
27

38

49
class UnionFind:
10+
"""
11+
A Union-Find data structure that supports union and find operations.
12+
13+
This implementation uses path compression for efficient find operations
14+
and union by size to keep the tree flat. It allows for efficient
15+
determination of connected components in a set of elements.
16+
17+
:param X: An iterable of elements to initialize the Union-Find structure.
18+
"""
519

620
_parent: dict[Any, Any]
721
_size: dict[Any, int]
822

9-
def __init__(self, X: Iterable[Any]):
10-
self._parent = {x: x for x in X}
11-
self._size = {x: 1 for x in X}
23+
def __init__(self, items: Iterable[Any]):
24+
self._parent = {x: x for x in items}
25+
self._size = {x: 1 for x in items}
1226

1327
def find(self, x: Any) -> Any:
28+
"""
29+
Finds the class of an element, applying path compression.
30+
31+
:param x: The element to find the class of.
32+
:return: The representative of the class containing x.
33+
"""
1434
root = x
1535
while root != self._parent[root]:
1636
root = self._parent[root]
@@ -22,6 +42,13 @@ def find(self, x: Any) -> Any:
2242
return root
2343

2444
def union(self, x: Any, y: Any) -> Any:
45+
"""
46+
Unites the classes of two elements.
47+
48+
:param x: The first element.
49+
:param y: The second element.
50+
:return: The representative of the class after the union operation.
51+
"""
2552
x, y = self.find(x), self.find(y)
2653
if x != y:
2754
x_size, y_size = self._size[x], self._size[y]

src/tdamapper/utils/vptree_flat/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def build(self) -> VPArray[T]:
9191
9292
:return: A tuple containing the constructed vp-tree and the VPArray.
9393
"""
94-
self._build_iter()
94+
if self._array.size() > 0:
95+
self._build_iter()
9596
return self._array
9697

9798
def _build_iter(self) -> None:

src/tdamapper/utils/vptree_hier/builder.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,15 @@ def build(self) -> tuple[Tree[T], VPArray[T]]:
9494
9595
:return: A tuple containing the constructed vp-tree and the VPArray.
9696
"""
97-
tree = self._build_rec(0, self._array.size())
97+
if self._array.size() > 0:
98+
tree = self._build_rec(0, self._array.size())
99+
else:
100+
tree = Leaf(0, 0)
98101
return tree, self._array
99102

100103
def _build_rec(self, start: int, end: int) -> Tree[T]:
104+
if end - start <= self._leaf_capacity:
105+
return Leaf(start, end)
101106
mid = _mid(start, end)
102107
self._update(start, end)
103108
v_point = self._array.get_point(start)
@@ -106,7 +111,7 @@ def _build_rec(self, start: int, end: int) -> Tree[T]:
106111
self._array.set_distance(start, v_radius)
107112
left: Tree[T]
108113
right: Tree[T]
109-
if (end - start <= 2 * self._leaf_capacity) or (v_radius <= self._leaf_radius):
114+
if v_radius <= self._leaf_radius:
110115
left = Leaf(start + 1, mid)
111116
right = Leaf(mid, end)
112117
else:

0 commit comments

Comments
 (0)