Skip to content

Commit 7ef70ed

Browse files
authored
[ENH] Use np.argpartition for efficient top-k selection instead of np.argsort (#2805)
* use argpartition for efficient selection instead of argsort
* made the required changes
1 parent d140fe9 commit 7ef70ed

File tree

1 file changed

+35
-23
lines changed

1 file changed

+35
-23
lines changed

aeon/similarity_search/series/_commons.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -137,36 +137,48 @@ def _extract_top_k_from_dist_profile(
137137
top_k_distances = np.full(k, np.inf, dtype=np.float64)
138138
ub = np.full(k, np.inf)
139139
lb = np.full(k, -1.0)
140-
# Could be optimized by using argpartition
141-
sorted_indexes = np.argsort(dist_profile)
140+
141+
remaining_indices = np.arange(len(dist_profile))
142+
mask = np.full(len(dist_profile), True)
142143
_current_k = 0
144+
143145
if not allow_trivial_matches:
144-
_current_j = 0
145-
# Until we extract k value or explore all the array or until dist is > threshold
146-
while _current_k < k and _current_j < len(sorted_indexes):
147-
# if we didn't insert anything or there is a conflict in lb/ub
148-
if _current_k > 0 and np.any(
149-
(sorted_indexes[_current_j] >= lb[:_current_k])
150-
& (sorted_indexes[_current_j] <= ub[:_current_k])
151-
):
152-
pass
153-
else:
154-
_idx = sorted_indexes[_current_j]
155-
if dist_profile[_idx] <= threshold:
156-
top_k_indexes[_current_k] = _idx
157-
top_k_distances[_current_k] = dist_profile[_idx]
158-
ub[_current_k] = min(
159-
top_k_indexes[_current_k] + exclusion_size,
160-
len(dist_profile),
161-
)
162-
lb[_current_k] = max(top_k_indexes[_current_k] - exclusion_size, 0)
146+
while _current_k < k and np.any(mask):
147+
available_indices = remaining_indices[mask]
148+
search_k = min(k, len(available_indices))
149+
if search_k == 0:
150+
break
151+
partitioned = available_indices[
152+
np.argpartition(dist_profile[available_indices], search_k - 1)[
153+
:search_k
154+
]
155+
]
156+
sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])]
157+
158+
for idx in sorted_indexes:
159+
if _current_k > 0 and np.any(
160+
(idx >= lb[:_current_k]) & (idx <= ub[:_current_k])
161+
):
162+
continue
163+
164+
if dist_profile[idx] <= threshold:
165+
top_k_indexes[_current_k] = idx
166+
top_k_distances[_current_k] = dist_profile[idx]
167+
ub[_current_k] = min(idx + exclusion_size, len(dist_profile))
168+
lb[_current_k] = max(idx - exclusion_size, 0)
163169
_current_k += 1
164170
else:
165171
break
166-
_current_j += 1
172+
173+
if _current_k == k:
174+
break
175+
176+
mask[sorted_indexes] = False
167177
else:
168178
_current_k += min(k, len(dist_profile))
169-
dist_profile = dist_profile[sorted_indexes[:_current_k]]
179+
partitioned = np.argpartition(dist_profile, k)[:k]
180+
sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])]
181+
dist_profile = dist_profile[sorted_indexes]
170182
dist_profile = dist_profile[dist_profile <= threshold]
171183
_current_k = len(dist_profile)
172184

0 commit comments

Comments
 (0)