diff --git a/aeon/similarity_search/series/_commons.py b/aeon/similarity_search/series/_commons.py index fa3346df11..0e1df4235f 100644 --- a/aeon/similarity_search/series/_commons.py +++ b/aeon/similarity_search/series/_commons.py @@ -137,36 +137,48 @@ def _extract_top_k_from_dist_profile( top_k_distances = np.full(k, np.inf, dtype=np.float64) ub = np.full(k, np.inf) lb = np.full(k, -1.0) - # Could be optimized by using argpartition - sorted_indexes = np.argsort(dist_profile) + + remaining_indices = np.arange(len(dist_profile)) + mask = np.full(len(dist_profile), True) _current_k = 0 + if not allow_trivial_matches: - _current_j = 0 - # Until we extract k value or explore all the array or until dist is > threshold - while _current_k < k and _current_j < len(sorted_indexes): - # if we didn't insert anything or there is a conflict in lb/ub - if _current_k > 0 and np.any( - (sorted_indexes[_current_j] >= lb[:_current_k]) - & (sorted_indexes[_current_j] <= ub[:_current_k]) - ): - pass - else: - _idx = sorted_indexes[_current_j] - if dist_profile[_idx] <= threshold: - top_k_indexes[_current_k] = _idx - top_k_distances[_current_k] = dist_profile[_idx] - ub[_current_k] = min( - top_k_indexes[_current_k] + exclusion_size, - len(dist_profile), - ) - lb[_current_k] = max(top_k_indexes[_current_k] - exclusion_size, 0) + while _current_k < k and np.any(mask): + available_indices = remaining_indices[mask] + search_k = min(k, len(available_indices)) + if search_k == 0: + break + partitioned = available_indices[ + np.argpartition(dist_profile[available_indices], search_k - 1)[ + :search_k + ] + ] + sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])] + + for idx in sorted_indexes: + if _current_k > 0 and np.any( + (idx >= lb[:_current_k]) & (idx <= ub[:_current_k]) + ): + continue + + if dist_profile[idx] <= threshold: + top_k_indexes[_current_k] = idx + top_k_distances[_current_k] = dist_profile[idx] + ub[_current_k] = min(idx + exclusion_size, len(dist_profile)) + lb[_current_k] = max(idx - exclusion_size, 0) _current_k += 1 else: break - _current_j += 1 + + if _current_k == k: + break + + mask[sorted_indexes] = False else: _current_k += min(k, len(dist_profile)) - dist_profile = dist_profile[sorted_indexes[:_current_k]] + partitioned = np.argpartition(dist_profile, k)[:k] + sorted_indexes = partitioned[np.argsort(dist_profile[partitioned])] + dist_profile = dist_profile[sorted_indexes] dist_profile = dist_profile[dist_profile <= threshold] _current_k = len(dist_profile)