From 8a7395731650b49ba5a9552d3695bd5d70b19d76 Mon Sep 17 00:00:00 2001 From: Kaustbh Date: Wed, 14 May 2025 21:42:44 +0530 Subject: [PATCH 1/2] use argpartition for efficient selection instead of argsort --- aeon/similarity_search/series/_commons.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/aeon/similarity_search/series/_commons.py b/aeon/similarity_search/series/_commons.py index 646c38e5ff..7041f628f6 100644 --- a/aeon/similarity_search/series/_commons.py +++ b/aeon/similarity_search/series/_commons.py @@ -137,8 +137,10 @@ def _extract_top_k_from_dist_profile( top_k_distances = np.full(k, np.inf, dtype=np.float64) ub = np.full(k, np.inf) lb = np.full(k, -1.0) - # Could be optimized by using argpartition - sorted_indexes = np.argsort(dist_profile) + + k = min(k, len(dist_profile)) + partitioned = np.argpartition(dist_profile, k)[:k] + sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])] _current_k = 0 if not allow_trivial_matches: _current_j = 0 @@ -165,8 +167,7 @@ def _extract_top_k_from_dist_profile( break _current_j += 1 else: - _current_k += min(k, len(dist_profile)) - dist_profile = dist_profile[sorted_indexes[:_current_k]] + dist_profile = dist_profile[sorted_indexes] dist_profile = dist_profile[dist_profile <= threshold] _current_k = len(dist_profile) From 9163cc2b802a4f401aa1e7d6916e2ddb5ece12ff Mon Sep 17 00:00:00 2001 From: Kaustbh Date: Fri, 23 May 2025 22:20:33 +0530 Subject: [PATCH 2/2] made the required changes --- aeon/similarity_search/series/_commons.py | 57 ++++++++++++++--------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/aeon/similarity_search/series/_commons.py b/aeon/similarity_search/series/_commons.py index b486e7b1c8..0e1df4235f 100644 --- a/aeon/similarity_search/series/_commons.py +++ b/aeon/similarity_search/series/_commons.py @@ -138,35 +138,46 @@ def _extract_top_k_from_dist_profile( ub = np.full(k, np.inf) lb = np.full(k, -1.0) - k = min(k, len(dist_profile)) - partitioned = np.argpartition(dist_profile, k)[:k] - sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])] + remaining_indices = np.arange(len(dist_profile)) + mask = np.full(len(dist_profile), True) _current_k = 0 + if not allow_trivial_matches: - _current_j = 0 - # Until we extract k value or explore all the array or until dist is > threshold - while _current_k < k and _current_j < len(sorted_indexes): - # if we didn't insert anything or there is a conflict in lb/ub - if _current_k > 0 and np.any( - (sorted_indexes[_current_j] >= lb[:_current_k]) - & (sorted_indexes[_current_j] <= ub[:_current_k]) - ): - pass - else: - _idx = sorted_indexes[_current_j] - if dist_profile[_idx] <= threshold: - top_k_indexes[_current_k] = _idx - top_k_distances[_current_k] = dist_profile[_idx] - ub[_current_k] = min( - top_k_indexes[_current_k] + exclusion_size, - len(dist_profile), - ) - lb[_current_k] = max(top_k_indexes[_current_k] - exclusion_size, 0) + while _current_k < k and np.any(mask): + available_indices = remaining_indices[mask] + search_k = min(k, len(available_indices)) + if search_k == 0: + break + partitioned = available_indices[ + np.argpartition(dist_profile[available_indices], search_k - 1)[ + :search_k + ] + ] + sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])] + + for idx in sorted_indexes: + if _current_k > 0 and np.any( + (idx >= lb[:_current_k]) & (idx <= ub[:_current_k]) + ): + continue + + if dist_profile[idx] <= threshold: + top_k_indexes[_current_k] = idx + top_k_distances[_current_k] = dist_profile[idx] + ub[_current_k] = min(idx + exclusion_size, len(dist_profile)) + lb[_current_k] = max(idx - exclusion_size, 0) _current_k += 1 else: break - _current_j += 1 + + if _current_k == k: + break + + mask[sorted_indexes] = False else: + _current_k += min(k, len(dist_profile)) + partitioned = np.argpartition(dist_profile, k)[:k] + sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])] dist_profile = dist_profile[sorted_indexes] dist_profile = dist_profile[dist_profile <= threshold] _current_k = len(dist_profile)