@@ -1118,15 +1118,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1118
1118
InMemQueryScratch<T> *scratch, const uint32_t Lsize, const std::vector<uint32_t > &init_ids, bool use_filter,
1119
1119
const std::vector<std::vector<LabelT>> &filter_labels, bool search_invocation)
1120
1120
{
1121
-
1122
- /* for (auto &x : filter_labels) {
1123
- std::cout<<"(";
1124
- for (auto &y : x) {
1125
- std::cout<<y<<"|";
1126
- }
1127
- std::cout<<")&";
1128
- }*/
1129
-
1130
1121
std::vector<Neighbor> &expanded_nodes = scratch->pool ();
1131
1122
NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes ();
1132
1123
best_L_nodes.reserve (Lsize);
@@ -1137,7 +1128,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1137
1128
assert (id_scratch.size () == 0 );
1138
1129
T *aligned_query = scratch->aligned_query ();
1139
1130
uint32_t hops = 0 ;
1140
- uint32_t cmps = 0 ;
1131
+ uint32_t cmps = 0 ;
1141
1132
1142
1133
float *pq_dists = nullptr ;
1143
1134
@@ -1170,20 +1161,14 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1170
1161
_pq_data_store->get_distance (scratch->aligned_query (), ids, dists_out, scratch);
1171
1162
};
1172
1163
1173
- // Initialize the candidate pool with starting points - BATCH PROCESS ALL INITIAL CANDIDATES
1174
- // Initialize the candidate pool with starting points - BATCH PROCESS ALL INITIAL CANDIDATES
1175
- // std::cout << "DEBUG: Starting with " << init_ids.size() << " initial candidates" << std::endl;
1164
+ // Initialize the candidate pool with starting points
1176
1165
uint32_t candidates_added = 0 ;
1177
1166
uint32_t candidates_filtered = 0 ;
1178
1167
1179
1168
// STEP 1: Collect all valid initial candidates first
1180
1169
std::vector<uint32_t > valid_init_candidates;
1181
1170
valid_init_candidates.reserve (init_ids.size ());
1182
1171
1183
- // STEP 1: Collect all valid initial candidates first
1184
- std::vector<uint32_t > valid_init_candidates;
1185
- valid_init_candidates.reserve (init_ids.size ());
1186
-
1187
1172
for (auto id : init_ids)
1188
1173
{
1189
1174
if (id >= _max_points + _num_frozen_pts)
@@ -1237,65 +1222,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1237
1222
init_penalties.resize (valid_init_candidates.size (), 0 .0f );
1238
1223
}
1239
1224
1240
- // STEP 3: Process all candidates with their pre-computed penalties
1241
- for (size_t i = 0 ; i < valid_init_candidates.size (); ++i)
1242
- {
1243
- uint32_t id = valid_init_candidates[i];
1244
- float res = init_penalties[i];
1245
-
1246
- candidates_added++;
1247
- if (fast_iterate)
1248
- {
1249
- inserted_into_pool_bs.add (id);
1250
- }
1251
- else
1252
- {
1253
- inserted_into_pool_rs.insert (id);
1254
- }
1255
- // Pre-filter for non-search invocations
1256
- if (use_filter && !search_invocation)
1257
- {
1258
- uint32_t common_count = detect_common_filters (id, search_invocation, filter_labels);
1259
- if (common_count < min_inter_size)
1260
- {
1261
- candidates_filtered++;
1262
- if (print_qstats)
1263
- {
1264
- std::ofstream out (" query_stats.txt" , std::ios_base::app);
1265
- out << " FILTERED OUT: id " << id << " has only " << common_count << " common filters (need " << min_inter_size << " )" << std::endl;
1266
- out.close ();
1267
- }
1268
- continue ;
1269
- }
1270
- }
1271
-
1272
- if (is_not_visited (id))
1273
- {
1274
- valid_init_candidates.push_back (id);
1275
- }
1276
- }
1277
-
1278
- // STEP 2: BATCH PROCESS all initial candidates' filter scores at once
1279
- std::vector<float > init_penalties;
1280
- if (use_filter && !valid_init_candidates.empty ())
1281
- {
1282
- auto jaccard_start = std::chrono::high_resolution_clock::now ();
1283
- std::vector<float > jaccard_similarities;
1284
- calculate_jaccard_similarity_batch (valid_init_candidates, filter_labels, jaccard_similarities);
1285
- std::chrono::duration<double > jaccard_diff = std::chrono::high_resolution_clock::now () - jaccard_start;
1286
- curr_jaccard_time += jaccard_diff.count ();
1287
-
1288
- // Convert similarities to penalties
1289
- init_penalties.reserve (jaccard_similarities.size ());
1290
- for (float sim : jaccard_similarities) {
1291
- init_penalties.push_back (1 .0f - sim);
1292
- }
1293
- }
1294
- else
1295
- {
1296
- init_penalties.resize (valid_init_candidates.size (), 0 .0f );
1297
- }
1298
-
1299
1225
// STEP 3: Process all candidates with their pre-computed penalties
1300
1226
for (size_t i = 0 ; i < valid_init_candidates.size (); ++i)
1301
1227
{
@@ -1326,19 +1252,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1326
1252
Neighbor nn = Neighbor (id, distance);
1327
1253
best_L_nodes.insert (nn);
1328
1254
1329
- if (print_qstats)
1330
- {
1331
- std::ofstream out (" query_stats.txt" , std::ios_base::app);
1332
- out << " BATCH INIT: id " << id << " has penalty " << res << " and filters " ;
1333
- for (auto const &filter : _location_to_labels[id])
1334
- {
1335
- out << filter << " " ;
1336
- }
1337
- out << std::endl;
1338
- out.close ();
1339
- Neighbor nn = Neighbor (id, distance);
1340
- best_L_nodes.insert (nn);
1341
-
1342
1255
if (print_qstats)
1343
1256
{
1344
1257
std::ofstream out (" query_stats.txt" , std::ios_base::app);
@@ -1508,62 +1421,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1508
1421
id_scratch.push_back (id);
1509
1422
}
1510
1423
1511
- // BATCH PROCESSING: Process all Jaccard similarities at once
1512
- if (use_filter && !id_scratch.empty ())
1513
- {
1514
- auto jaccard_start = std::chrono::high_resolution_clock::now ();
1515
- std::vector<float > jaccard_similarities;
1516
- calculate_jaccard_similarity_batch (id_scratch, filter_labels, jaccard_similarities);
1517
- std::chrono::duration<double > jaccard_diff = std::chrono::high_resolution_clock::now () - jaccard_start;
1518
- curr_jaccard_time += jaccard_diff.count ();
1519
-
1520
- // Convert similarities to penalties and apply filtering
1521
- res_vec.reserve (id_scratch.size ());
1522
- std::vector<uint32_t > filtered_ids;
1523
- filtered_ids.reserve (id_scratch.size ());
1524
-
1525
- for (size_t i = 0 ; i < id_scratch.size (); ++i)
1526
- {
1527
- float penalty = 1 .0f - jaccard_similarities[i];
1528
-
1529
- // Optional: Filter out points with very high penalty during search
1530
- if (search_invocation && penalty > 0 .95f ) // 95% penalty threshold
1531
- continue ;
1532
-
1533
- res_vec.push_back (penalty);
1534
- filtered_ids.push_back (id_scratch[i]);
1535
-
1536
- if (print_qstats)
1537
- {
1538
- std::ofstream out (" query_stats.txt" , std::ios_base::app);
1539
- out << " BATCH processed nbr " << id_scratch[i] << " penalty " << penalty << " with filters " ;
1540
- for (auto const &filter : _location_to_labels[id_scratch[i]])
1541
- {
1542
- out << filter << " " ;
1543
- }
1544
- out << std::endl;
1545
- out.close ();
1546
- }
1547
- }
1548
-
1549
- // Replace id_scratch with filtered IDs
1550
- id_scratch = std::move (filtered_ids);
1551
- }
1552
- else if (!use_filter)
1553
- {
1554
- // No filtering - fill res_vec with zeros
1555
- res_vec.resize (id_scratch.size (), 0 .0f );
1556
- // Pre-filter check for non-search invocations
1557
- if (use_filter && !search_invocation)
1558
- {
1559
- if (detect_common_filters (id, search_invocation, filter_labels) < min_inter_size)
1560
- continue ;
1561
- }
1562
-
1563
- // Collect all valid IDs first for batch processing
1564
- id_scratch.push_back (id);
1565
- }
1566
-
1567
1424
// BATCH PROCESSING: Process all Jaccard similarities at once
1568
1425
if (use_filter && !id_scratch.empty ())
1569
1426
{
@@ -1619,11 +1476,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1619
1476
out.close ();
1620
1477
}
1621
1478
1622
- // Mark nodes visited
1623
- /* for (auto id : id_scratch)
1624
- {
1625
- } */
1626
-
1479
+ // Mark nodes visited and compute distances
1627
1480
assert (dist_scratch.capacity () >= id_scratch.size ());
1628
1481
compute_dists (id_scratch, dist_scratch);
1629
1482
cmps += static_cast <uint32_t >(id_scratch.size ()); // Count distance comparisons
@@ -1632,7 +1485,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1632
1485
// Insert <id, dist> pairs into the pool of candidates
1633
1486
for (size_t m = 0 ; m < id_scratch.size (); ++m)
1634
1487
{
1635
- // Apply normalization to vector distance before combining with filter score
1488
+ // Apply normalization to vector distance before combining with filter score
1636
1489
float normalized_distance = g_normalization_config.normalize_distance (dist_scratch[m]);
1637
1490
best_L_nodes.insert (Neighbor (id_scratch[m], normalized_distance + w_m * res_vec[m]));
1638
1491
}
0 commit comments