@@ -1130,6 +1130,10 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1130
1130
uint32_t hops = 0 ;
1131
1131
uint32_t cmps = 0 ;
1132
1132
1133
+ // Pre-size scratch vectors for batch operations to avoid reallocations
1134
+ size_t estimated_batch_size = std::max (init_ids.size (), (size_t )(Lsize * 2 ));
1135
+ scratch->resize_jaccard_scratch_for_batch (estimated_batch_size);
1136
+
1133
1137
float *pq_dists = nullptr ;
1134
1138
1135
1139
if (print_qstats)
@@ -1166,7 +1170,8 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1166
1170
uint32_t candidates_filtered = 0 ;
1167
1171
1168
1172
// STEP 1: Collect all valid initial candidates first
1169
- std::vector<uint32_t > valid_init_candidates;
1173
+ std::vector<uint32_t > &valid_init_candidates = scratch->valid_init_candidates_scratch ();
1174
+ valid_init_candidates.clear ();
1170
1175
valid_init_candidates.reserve (init_ids.size ());
1171
1176
1172
1177
for (auto id : init_ids)
@@ -1202,11 +1207,13 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1202
1207
}
1203
1208
1204
1209
// STEP 2: BATCH PROCESS all initial candidates' filter scores at once
1205
- std::vector<float > init_penalties;
1210
+ std::vector<float > &init_penalties = scratch->init_penalties_scratch ();
1211
+ init_penalties.clear ();
1206
1212
if (use_filter && !valid_init_candidates.empty ())
1207
1213
{
1208
1214
auto jaccard_start = std::chrono::high_resolution_clock::now ();
1209
- std::vector<float > jaccard_similarities;
1215
+ std::vector<float > &jaccard_similarities = scratch->jaccard_similarities_scratch ();
1216
+ jaccard_similarities.clear ();
1210
1217
calculate_jaccard_similarity_batch (valid_init_candidates, filter_labels, jaccard_similarities);
1211
1218
std::chrono::duration<double > jaccard_diff = std::chrono::high_resolution_clock::now () - jaccard_start;
1212
1219
curr_jaccard_time += jaccard_diff.count ();
@@ -1425,14 +1432,16 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1425
1432
if (use_filter && !id_scratch.empty ())
1426
1433
{
1427
1434
auto jaccard_start = std::chrono::high_resolution_clock::now ();
1428
- std::vector<float > jaccard_similarities;
1435
+ std::vector<float > &jaccard_similarities = scratch->jaccard_similarities_scratch ();
1436
+ jaccard_similarities.clear ();
1429
1437
calculate_jaccard_similarity_batch (id_scratch, filter_labels, jaccard_similarities);
1430
1438
std::chrono::duration<double > jaccard_diff = std::chrono::high_resolution_clock::now () - jaccard_start;
1431
1439
curr_jaccard_time += jaccard_diff.count ();
1432
1440
1433
1441
// Convert similarities to penalties and apply filtering
1434
1442
res_vec.reserve (id_scratch.size ());
1435
- std::vector<uint32_t > filtered_ids;
1443
+ std::vector<uint32_t > &filtered_ids = scratch->filtered_ids_scratch ();
1444
+ filtered_ids.clear ();
1436
1445
filtered_ids.reserve (id_scratch.size ());
1437
1446
1438
1447
for (size_t i = 0 ; i < id_scratch.size (); ++i)
0 commit comments