@@ -47,6 +47,8 @@ namespace diskann
47
47
int64_t curr_query = -1 ;
48
48
double curr_intersection_time = 0.0 ;
49
49
double curr_jaccard_time = 0.0 ;
50
+ uint32_t curr_init_cmps = 0 ; // Distance comparisons for initial candidates in current query
51
+ uint32_t curr_expansion_cmps = 0 ; // Distance comparisons for neighbor expansion in current query
50
52
uint32_t penalty_scale = 10 ;
51
53
float w_m = 1 .0f ;
52
54
uint32_t num_sp = 2 ;
@@ -1128,7 +1130,8 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1128
1130
assert (id_scratch.size () == 0 );
1129
1131
T *aligned_query = scratch->aligned_query ();
1130
1132
uint32_t hops = 0 ;
1131
- uint32_t cmps = 0 ;
1133
+ uint32_t init_cmps = 0 ; // Distance comparisons for initial candidates
1134
+ uint32_t expansion_cmps = 0 ; // Distance comparisons for neighbor expansion
1132
1135
1133
1136
// Pre-size scratch vectors for batch operations to avoid reallocations
1134
1137
size_t estimated_batch_size = std::max (init_ids.size (), (size_t )(Lsize * 2 ));
@@ -1229,14 +1232,21 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1229
1232
curr_jaccard_time += jaccard_diff.count ();
1230
1233
1231
1234
// Convert similarities to penalties
1232
- init_penalties.reserve (jaccard_similarities.size ());
1235
+ init_penalties.clear ();
1236
+ // Ensure capacity without reallocating - use assert to catch capacity issues early
1237
+ assert (init_penalties.capacity () >= jaccard_similarities.size ());
1233
1238
for (float sim : jaccard_similarities) {
1234
1239
init_penalties.push_back (1 .0f - sim);
1235
1240
}
1236
1241
}
1237
1242
else
1238
1243
{
1239
- init_penalties.resize (valid_init_candidates.size (), 0 .0f );
1244
+ init_penalties.clear ();
1245
+ // Ensure capacity without reallocating
1246
+ assert (init_penalties.capacity () >= valid_init_candidates.size ());
1247
+ for (size_t i = 0 ; i < valid_init_candidates.size (); ++i) {
1248
+ init_penalties.push_back (0 .0f );
1249
+ }
1240
1250
}
1241
1251
1242
1252
// STEP 3: Process all candidates with their pre-computed penalties
@@ -1267,7 +1277,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1267
1277
}
1268
1278
1269
1279
_pq_data_store->get_distance (aligned_query, ids, 1 , distances, scratch);
1270
- cmps ++; // Count this distance comparison
1280
+ init_cmps ++; // Count this as initial candidate distance comparison
1271
1281
1272
1282
// Apply normalization to vector distance
1273
1283
float normalized_distance = g_normalization_config.normalize_distance (distances[0 ]);
@@ -1456,22 +1466,16 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1456
1466
std::chrono::duration<double > jaccard_diff = std::chrono::high_resolution_clock::now () - jaccard_start;
1457
1467
curr_jaccard_time += jaccard_diff.count ();
1458
1468
1459
- // Convert similarities to penalties and apply filtering
1460
- res_vec.reserve (id_scratch.size ());
1461
- std::vector<uint32_t > &filtered_ids = scratch->filtered_ids_scratch ();
1462
- filtered_ids.clear ();
1463
- filtered_ids.reserve (id_scratch.size ());
1469
+ // Convert similarities to penalties - zero allocations
1470
+ res_vec.clear ();
1471
+ // Ensure capacity without reallocating - should be pre-allocated by scratch space
1472
+ assert (res_vec.capacity () >= jaccard_similarities.size ());
1464
1473
1465
- for (size_t i = 0 ; i < id_scratch.size (); ++i)
1474
+ // Simple conversion from similarities to penalties
1475
+ for (size_t i = 0 ; i < jaccard_similarities.size (); ++i)
1466
1476
{
1467
1477
float penalty = 1 .0f - jaccard_similarities[i];
1468
-
1469
- // Optional: Filter out points with very high penalty during search
1470
- if (search_invocation && penalty > 0 .95f ) // 95% penalty threshold
1471
- continue ;
1472
-
1473
1478
res_vec.push_back (penalty);
1474
- filtered_ids.push_back (id_scratch[i]);
1475
1479
1476
1480
if (print_qstats)
1477
1481
{
@@ -1485,14 +1489,15 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1485
1489
out.close ();
1486
1490
}
1487
1491
}
1488
-
1489
- // Replace id_scratch with filtered IDs
1490
- id_scratch = std::move (filtered_ids);
1491
1492
}
1492
1493
else if (!use_filter)
1493
1494
{
1494
- // No filtering - fill res_vec with zeros
1495
- res_vec.resize (id_scratch.size (), 0 .0f );
1495
+ // No filtering - fill res_vec with zeros, no allocations
1496
+ res_vec.clear ();
1497
+ assert (res_vec.capacity () >= id_scratch.size ());
1498
+ for (size_t i = 0 ; i < id_scratch.size (); ++i) {
1499
+ res_vec.push_back (0 .0f );
1500
+ }
1496
1501
}
1497
1502
}
1498
1503
@@ -1506,7 +1511,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1506
1511
// Mark nodes visited and compute distances
1507
1512
assert (dist_scratch.capacity () >= id_scratch.size ());
1508
1513
compute_dists (id_scratch, dist_scratch);
1509
- cmps += static_cast <uint32_t >(id_scratch.size ()); // Count distance comparisons
1514
+ expansion_cmps += static_cast <uint32_t >(id_scratch.size ()); // Count neighbor expansion distance comparisons
1510
1515
assert (res_vec.size () == id_scratch.size ());
1511
1516
1512
1517
// Insert <id, dist> pairs into the pool of candidates
@@ -1517,7 +1522,13 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1517
1522
best_L_nodes.insert (Neighbor (id_scratch[m], normalized_distance + w_m * res_vec[m]));
1518
1523
}
1519
1524
}
1520
- return std::make_pair (hops, cmps);
1525
+
1526
+ // Update global variables for query statistics
1527
+ curr_init_cmps = init_cmps;
1528
+ curr_expansion_cmps = expansion_cmps;
1529
+
1530
+ uint32_t total_cmps = init_cmps + expansion_cmps;
1531
+ return std::make_pair (hops, total_cmps);
1521
1532
}
1522
1533
1523
1534
template <typename T, typename TagT, typename LabelT>
@@ -2681,6 +2692,17 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
2681
2692
diskann::cerr << " Found pos: " << pos << " fewer than K elements " << K << " for query" << std::endl;
2682
2693
}
2683
2694
2695
+ // Print distance comparison statistics if enabled
2696
+ if (print_qstats)
2697
+ {
2698
+ std::ofstream out (" query_stats.txt" , std::ios_base::app);
2699
+ out << " Distance comparisons - Initial candidates: " << curr_init_cmps
2700
+ << " , Neighbor expansion: " << curr_expansion_cmps
2701
+ << " , Total: " << (curr_init_cmps + curr_expansion_cmps) << std::endl;
2702
+ out << std::endl;
2703
+ out.close ();
2704
+ }
2705
+
2684
2706
return retval;
2685
2707
}
2686
2708
@@ -3138,6 +3160,18 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
3138
3160
// }
3139
3161
3140
3162
// std::cout << "[DEBUG] Inside search_with_filters: num_graphs = " << num_graphs << std::endl;
3163
+
3164
+ // Print distance comparison statistics if enabled
3165
+ if (print_qstats && local_print)
3166
+ {
3167
+ std::ofstream out (" query_stats.txt" , std::ios_base::app);
3168
+ out << " Distance comparisons - Initial candidates: " << curr_init_cmps
3169
+ << " , Neighbor expansion: " << curr_expansion_cmps
3170
+ << " , Total: " << (curr_init_cmps + curr_expansion_cmps) << std::endl;
3171
+ out << std::endl;
3172
+ out.close ();
3173
+ }
3174
+
3141
3175
return retval;
3142
3176
}
3143
3177
0 commit comments