Skip to content

Commit 2d679a0

Browse files
author
Ananya Sutradhar
committed
Removed memory allocations inside iterate_to_fixed_point
1 parent 1ffbd8b commit 2d679a0

File tree

2 files changed

+62
-23
lines changed

2 files changed

+62
-23
lines changed

apps/search_memory_index.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
336336
}
337337

338338
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
339+
340+
if(test_id == 0) {
341+
std::cout << "[PERF] Search complete. Press enter to continue...";
342+
std::cin.get();
343+
}
339344

340345
double displayed_qps = query_num / diff.count();
341346

src/index.cpp

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ namespace diskann
4747
int64_t curr_query = -1;
4848
double curr_intersection_time = 0.0;
4949
double curr_jaccard_time = 0.0;
50+
uint32_t curr_init_cmps = 0; // Distance comparisons for initial candidates in current query
51+
uint32_t curr_expansion_cmps = 0; // Distance comparisons for neighbor expansion in current query
5052
uint32_t penalty_scale = 10;
5153
float w_m = 1.0f;
5254
uint32_t num_sp = 2;
@@ -1128,7 +1130,8 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
11281130
assert(id_scratch.size() == 0);
11291131
T *aligned_query = scratch->aligned_query();
11301132
uint32_t hops = 0;
1131-
uint32_t cmps = 0;
1133+
uint32_t init_cmps = 0; // Distance comparisons for initial candidates
1134+
uint32_t expansion_cmps = 0; // Distance comparisons for neighbor expansion
11321135

11331136
// Pre-size scratch vectors for batch operations to avoid reallocations
11341137
size_t estimated_batch_size = std::max(init_ids.size(), (size_t)(Lsize * 2));
@@ -1229,14 +1232,21 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
12291232
curr_jaccard_time += jaccard_diff.count();
12301233

12311234
// Convert similarities to penalties
1232-
init_penalties.reserve(jaccard_similarities.size());
1235+
init_penalties.clear();
1236+
// Verify (debug builds only) that capacity is already sufficient, so the push_backs below do not reallocate
1237+
assert(init_penalties.capacity() >= jaccard_similarities.size());
12331238
for (float sim : jaccard_similarities) {
12341239
init_penalties.push_back(1.0f - sim);
12351240
}
12361241
}
12371242
else
12381243
{
1239-
init_penalties.resize(valid_init_candidates.size(), 0.0f);
1244+
init_penalties.clear();
1245+
// Verify (debug builds only) that capacity is already sufficient, so the push_backs below do not reallocate
1246+
assert(init_penalties.capacity() >= valid_init_candidates.size());
1247+
for (size_t i = 0; i < valid_init_candidates.size(); ++i) {
1248+
init_penalties.push_back(0.0f);
1249+
}
12401250
}
12411251

12421252
// STEP 3: Process all candidates with their pre-computed penalties
@@ -1267,7 +1277,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
12671277
}
12681278

12691279
_pq_data_store->get_distance(aligned_query, ids, 1, distances, scratch);
1270-
cmps++; // Count this distance comparison
1280+
init_cmps++; // Count this as initial candidate distance comparison
12711281

12721282
// Apply normalization to vector distance
12731283
float normalized_distance = g_normalization_config.normalize_distance(distances[0]);
@@ -1456,22 +1466,16 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
14561466
std::chrono::duration<double> jaccard_diff = std::chrono::high_resolution_clock::now() - jaccard_start;
14571467
curr_jaccard_time += jaccard_diff.count();
14581468

1459-
// Convert similarities to penalties and apply filtering
1460-
res_vec.reserve(id_scratch.size());
1461-
std::vector<uint32_t> &filtered_ids = scratch->filtered_ids_scratch();
1462-
filtered_ids.clear();
1463-
filtered_ids.reserve(id_scratch.size());
1469+
// Convert similarities to penalties - zero allocations
1470+
res_vec.clear();
1471+
// Verify (debug builds only) that capacity is already sufficient - it should be pre-allocated by scratch space
1472+
assert(res_vec.capacity() >= jaccard_similarities.size());
14641473

1465-
for (size_t i = 0; i < id_scratch.size(); ++i)
1474+
// Simple conversion from similarities to penalties
1475+
for (size_t i = 0; i < jaccard_similarities.size(); ++i)
14661476
{
14671477
float penalty = 1.0f - jaccard_similarities[i];
1468-
1469-
// Optional: Filter out points with very high penalty during search
1470-
if (search_invocation && penalty > 0.95f) // 95% penalty threshold
1471-
continue;
1472-
14731478
res_vec.push_back(penalty);
1474-
filtered_ids.push_back(id_scratch[i]);
14751479

14761480
if (print_qstats)
14771481
{
@@ -1485,14 +1489,15 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
14851489
out.close();
14861490
}
14871491
}
1488-
1489-
// Replace id_scratch with filtered IDs
1490-
id_scratch = std::move(filtered_ids);
14911492
}
14921493
else if (!use_filter)
14931494
{
1494-
// No filtering - fill res_vec with zeros
1495-
res_vec.resize(id_scratch.size(), 0.0f);
1495+
// No filtering - fill res_vec with zeros, no allocations
1496+
res_vec.clear();
1497+
assert(res_vec.capacity() >= id_scratch.size());
1498+
for (size_t i = 0; i < id_scratch.size(); ++i) {
1499+
res_vec.push_back(0.0f);
1500+
}
14961501
}
14971502
}
14981503

@@ -1506,7 +1511,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
15061511
// Mark nodes visited and compute distances
15071512
assert(dist_scratch.capacity() >= id_scratch.size());
15081513
compute_dists(id_scratch, dist_scratch);
1509-
cmps += static_cast<uint32_t>(id_scratch.size()); // Count distance comparisons
1514+
expansion_cmps += static_cast<uint32_t>(id_scratch.size()); // Count neighbor expansion distance comparisons
15101515
assert(res_vec.size() == id_scratch.size());
15111516

15121517
// Insert <id, dist> pairs into the pool of candidates
@@ -1517,7 +1522,13 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
15171522
best_L_nodes.insert(Neighbor(id_scratch[m], normalized_distance + w_m * res_vec[m]));
15181523
}
15191524
}
1520-
return std::make_pair(hops, cmps);
1525+
1526+
// Update global variables for query statistics
1527+
curr_init_cmps = init_cmps;
1528+
curr_expansion_cmps = expansion_cmps;
1529+
1530+
uint32_t total_cmps = init_cmps + expansion_cmps;
1531+
return std::make_pair(hops, total_cmps);
15211532
}
15221533

15231534
template <typename T, typename TagT, typename LabelT>
@@ -2681,6 +2692,17 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
26812692
diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl;
26822693
}
26832694

2695+
// Print distance comparison statistics if enabled
2696+
if (print_qstats)
2697+
{
2698+
std::ofstream out("query_stats.txt", std::ios_base::app);
2699+
out << "Distance comparisons - Initial candidates: " << curr_init_cmps
2700+
<< ", Neighbor expansion: " << curr_expansion_cmps
2701+
<< ", Total: " << (curr_init_cmps + curr_expansion_cmps) << std::endl;
2702+
out << std::endl;
2703+
out.close();
2704+
}
2705+
26842706
return retval;
26852707
}
26862708

@@ -3138,6 +3160,18 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
31383160
// }
31393161

31403162
//std::cout << "[DEBUG] Inside search_with_filters: num_graphs = " << num_graphs << std::endl;
3163+
3164+
// Print distance comparison statistics if enabled
3165+
if (print_qstats && local_print)
3166+
{
3167+
std::ofstream out("query_stats.txt", std::ios_base::app);
3168+
out << "Distance comparisons - Initial candidates: " << curr_init_cmps
3169+
<< ", Neighbor expansion: " << curr_expansion_cmps
3170+
<< ", Total: " << (curr_init_cmps + curr_expansion_cmps) << std::endl;
3171+
out << std::endl;
3172+
out.close();
3173+
}
3174+
31413175
return retval;
31423176
}
31433177

0 commit comments

Comments
 (0)