@@ -137,36 +137,48 @@ def _extract_top_k_from_dist_profile(
137
137
top_k_distances = np .full (k , np .inf , dtype = np .float64 )
138
138
ub = np .full (k , np .inf )
139
139
lb = np .full (k , - 1.0 )
140
- # Could be optimized by using argpartition
141
- sorted_indexes = np .argsort (dist_profile )
140
+
141
+ remaining_indices = np .arange (len (dist_profile ))
142
+ mask = np .full (len (dist_profile ), True )
142
143
_current_k = 0
144
+
143
145
if not allow_trivial_matches :
144
- _current_j = 0
145
- # Until we extract k value or explore all the array or until dist is > threshold
146
- while _current_k < k and _current_j < len (sorted_indexes ):
147
- # if we didn't insert anything or there is a conflict in lb/ub
148
- if _current_k > 0 and np .any (
149
- (sorted_indexes [_current_j ] >= lb [:_current_k ])
150
- & (sorted_indexes [_current_j ] <= ub [:_current_k ])
151
- ):
152
- pass
153
- else :
154
- _idx = sorted_indexes [_current_j ]
155
- if dist_profile [_idx ] <= threshold :
156
- top_k_indexes [_current_k ] = _idx
157
- top_k_distances [_current_k ] = dist_profile [_idx ]
158
- ub [_current_k ] = min (
159
- top_k_indexes [_current_k ] + exclusion_size ,
160
- len (dist_profile ),
161
- )
162
- lb [_current_k ] = max (top_k_indexes [_current_k ] - exclusion_size , 0 )
146
+ while _current_k < k and np .any (mask ):
147
+ available_indices = remaining_indices [mask ]
148
+ search_k = min (k , len (available_indices ))
149
+ if search_k == 0 :
150
+ break
151
+ partitioned = available_indices [
152
+ np .argpartition (dist_profile [available_indices ], search_k - 1 )[
153
+ :search_k
154
+ ]
155
+ ]
156
+ sorted_indexes = partitioned [np .argsort (dist_profile [partitioned ])]
157
+
158
+ for idx in sorted_indexes :
159
+ if _current_k > 0 and np .any (
160
+ (idx >= lb [:_current_k ]) & (idx <= ub [:_current_k ])
161
+ ):
162
+ continue
163
+
164
+ if dist_profile [idx ] <= threshold :
165
+ top_k_indexes [_current_k ] = idx
166
+ top_k_distances [_current_k ] = dist_profile [idx ]
167
+ ub [_current_k ] = min (idx + exclusion_size , len (dist_profile ))
168
+ lb [_current_k ] = max (idx - exclusion_size , 0 )
163
169
_current_k += 1
164
170
else :
165
171
break
166
- _current_j += 1
172
+
173
+ if _current_k == k :
174
+ break
175
+
176
+ mask [sorted_indexes ] = False
167
177
else :
168
178
_current_k += min (k , len (dist_profile ))
169
- dist_profile = dist_profile [sorted_indexes [:_current_k ]]
179
+ partitioned = np .argpartition (dist_profile , k )[:k ]
180
+ sorted_indexes = partitioned [np .argsort (dist_profile [partitioned ])]
181
+ dist_profile = dist_profile [sorted_indexes ]
170
182
dist_profile = dist_profile [dist_profile <= threshold ]
171
183
_current_k = len (dist_profile )
172
184
0 commit comments