Skip to content

Commit 9318e9a

Browse files
author
Ananya Sutradhar
committed
fixed bug in batch process
1 parent bcafc1d commit 9318e9a

File tree

1 file changed

+4
-151
lines changed

1 file changed

+4
-151
lines changed

src/index.cpp

Lines changed: 4 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,15 +1118,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
11181118
InMemQueryScratch<T> *scratch, const uint32_t Lsize, const std::vector<uint32_t> &init_ids, bool use_filter,
11191119
const std::vector<std::vector<LabelT>> &filter_labels, bool search_invocation)
11201120
{
1121-
1122-
/* for (auto &x : filter_labels) {
1123-
std::cout<<"(";
1124-
for (auto &y : x) {
1125-
std::cout<<y<<"|";
1126-
}
1127-
std::cout<<")&";
1128-
}*/
1129-
11301121
std::vector<Neighbor> &expanded_nodes = scratch->pool();
11311122
NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes();
11321123
best_L_nodes.reserve(Lsize);
@@ -1137,7 +1128,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
11371128
assert(id_scratch.size() == 0);
11381129
T *aligned_query = scratch->aligned_query();
11391130
uint32_t hops = 0;
1140-
uint32_t cmps = 0;
1131+
uint32_t cmps = 0;
11411132

11421133
float *pq_dists = nullptr;
11431134

@@ -1170,20 +1161,14 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
11701161
_pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch);
11711162
};
11721163

1173-
// Initialize the candidate pool with starting points - BATCH PROCESS ALL INITIAL CANDIDATES
1174-
// Initialize the candidate pool with starting points - BATCH PROCESS ALL INITIAL CANDIDATES
1175-
// std::cout << "DEBUG: Starting with " << init_ids.size() << " initial candidates" << std::endl;
1164+
// Initialize the candidate pool with starting points
11761165
uint32_t candidates_added = 0;
11771166
uint32_t candidates_filtered = 0;
11781167

11791168
// STEP 1: Collect all valid initial candidates first
11801169
std::vector<uint32_t> valid_init_candidates;
11811170
valid_init_candidates.reserve(init_ids.size());
11821171

1183-
// STEP 1: Collect all valid initial candidates first
1184-
std::vector<uint32_t> valid_init_candidates;
1185-
valid_init_candidates.reserve(init_ids.size());
1186-
11871172
for (auto id : init_ids)
11881173
{
11891174
if (id >= _max_points + _num_frozen_pts)
@@ -1237,65 +1222,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
12371222
init_penalties.resize(valid_init_candidates.size(), 0.0f);
12381223
}
12391224

1240-
// STEP 3: Process all candidates with their pre-computed penalties
1241-
for (size_t i = 0; i < valid_init_candidates.size(); ++i)
1242-
{
1243-
uint32_t id = valid_init_candidates[i];
1244-
float res = init_penalties[i];
1245-
1246-
candidates_added++;
1247-
if (fast_iterate)
1248-
{
1249-
inserted_into_pool_bs.add(id);
1250-
}
1251-
else
1252-
{
1253-
inserted_into_pool_rs.insert(id);
1254-
}
1255-
// Pre-filter for non-search invocations
1256-
if (use_filter && !search_invocation)
1257-
{
1258-
uint32_t common_count = detect_common_filters(id, search_invocation, filter_labels);
1259-
if (common_count < min_inter_size)
1260-
{
1261-
candidates_filtered++;
1262-
if (print_qstats)
1263-
{
1264-
std::ofstream out("query_stats.txt", std::ios_base::app);
1265-
out << "FILTERED OUT: id " << id << " has only " << common_count << " common filters (need " << min_inter_size << ")" << std::endl;
1266-
out.close();
1267-
}
1268-
continue;
1269-
}
1270-
}
1271-
1272-
if (is_not_visited(id))
1273-
{
1274-
valid_init_candidates.push_back(id);
1275-
}
1276-
}
1277-
1278-
// STEP 2: BATCH PROCESS all initial candidates' filter scores at once
1279-
std::vector<float> init_penalties;
1280-
if (use_filter && !valid_init_candidates.empty())
1281-
{
1282-
auto jaccard_start = std::chrono::high_resolution_clock::now();
1283-
std::vector<float> jaccard_similarities;
1284-
calculate_jaccard_similarity_batch(valid_init_candidates, filter_labels, jaccard_similarities);
1285-
std::chrono::duration<double> jaccard_diff = std::chrono::high_resolution_clock::now() - jaccard_start;
1286-
curr_jaccard_time += jaccard_diff.count();
1287-
1288-
// Convert similarities to penalties
1289-
init_penalties.reserve(jaccard_similarities.size());
1290-
for (float sim : jaccard_similarities) {
1291-
init_penalties.push_back(1.0f - sim);
1292-
}
1293-
}
1294-
else
1295-
{
1296-
init_penalties.resize(valid_init_candidates.size(), 0.0f);
1297-
}
1298-
12991225
// STEP 3: Process all candidates with their pre-computed penalties
13001226
for (size_t i = 0; i < valid_init_candidates.size(); ++i)
13011227
{
@@ -1326,19 +1252,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
13261252
Neighbor nn = Neighbor(id, distance);
13271253
best_L_nodes.insert(nn);
13281254

1329-
if (print_qstats)
1330-
{
1331-
std::ofstream out("query_stats.txt", std::ios_base::app);
1332-
out << "BATCH INIT: id " << id << " has penalty " << res << " and filters ";
1333-
for (auto const &filter : _location_to_labels[id])
1334-
{
1335-
out << filter << " ";
1336-
}
1337-
out << std::endl;
1338-
out.close();
1339-
Neighbor nn = Neighbor(id, distance);
1340-
best_L_nodes.insert(nn);
1341-
13421255
if (print_qstats)
13431256
{
13441257
std::ofstream out("query_stats.txt", std::ios_base::app);
@@ -1508,62 +1421,6 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
15081421
id_scratch.push_back(id);
15091422
}
15101423

1511-
// BATCH PROCESSING: Process all Jaccard similarities at once
1512-
if (use_filter && !id_scratch.empty())
1513-
{
1514-
auto jaccard_start = std::chrono::high_resolution_clock::now();
1515-
std::vector<float> jaccard_similarities;
1516-
calculate_jaccard_similarity_batch(id_scratch, filter_labels, jaccard_similarities);
1517-
std::chrono::duration<double> jaccard_diff = std::chrono::high_resolution_clock::now() - jaccard_start;
1518-
curr_jaccard_time += jaccard_diff.count();
1519-
1520-
// Convert similarities to penalties and apply filtering
1521-
res_vec.reserve(id_scratch.size());
1522-
std::vector<uint32_t> filtered_ids;
1523-
filtered_ids.reserve(id_scratch.size());
1524-
1525-
for (size_t i = 0; i < id_scratch.size(); ++i)
1526-
{
1527-
float penalty = 1.0f - jaccard_similarities[i];
1528-
1529-
// Optional: Filter out points with very high penalty during search
1530-
if (search_invocation && penalty > 0.95f) // 95% penalty threshold
1531-
continue;
1532-
1533-
res_vec.push_back(penalty);
1534-
filtered_ids.push_back(id_scratch[i]);
1535-
1536-
if (print_qstats)
1537-
{
1538-
std::ofstream out("query_stats.txt", std::ios_base::app);
1539-
out << "BATCH processed nbr " << id_scratch[i] << " penalty " << penalty << " with filters ";
1540-
for (auto const &filter : _location_to_labels[id_scratch[i]])
1541-
{
1542-
out << filter << " ";
1543-
}
1544-
out << std::endl;
1545-
out.close();
1546-
}
1547-
}
1548-
1549-
// Replace id_scratch with filtered IDs
1550-
id_scratch = std::move(filtered_ids);
1551-
}
1552-
else if (!use_filter)
1553-
{
1554-
// No filtering - fill res_vec with zeros
1555-
res_vec.resize(id_scratch.size(), 0.0f);
1556-
// Pre-filter check for non-search invocations
1557-
if (use_filter && !search_invocation)
1558-
{
1559-
if (detect_common_filters(id, search_invocation, filter_labels) < min_inter_size)
1560-
continue;
1561-
}
1562-
1563-
// Collect all valid IDs first for batch processing
1564-
id_scratch.push_back(id);
1565-
}
1566-
15671424
// BATCH PROCESSING: Process all Jaccard similarities at once
15681425
if (use_filter && !id_scratch.empty())
15691426
{
@@ -1619,11 +1476,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
16191476
out.close();
16201477
}
16211478

1622-
// Mark nodes visited
1623-
/* for (auto id : id_scratch)
1624-
{
1625-
} */
1626-
1479+
// Mark nodes visited and compute distances
16271480
assert(dist_scratch.capacity() >= id_scratch.size());
16281481
compute_dists(id_scratch, dist_scratch);
16291482
cmps += static_cast<uint32_t>(id_scratch.size()); // Count distance comparisons
@@ -1632,7 +1485,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
16321485
// Insert <id, dist> pairs into the pool of candidates
16331486
for (size_t m = 0; m < id_scratch.size(); ++m)
16341487
{
1635-
// Apply normalization to vector distance before combining with filter score
1488+
// Apply normalization to vector distance before combining with filter score
16361489
float normalized_distance = g_normalization_config.normalize_distance(dist_scratch[m]);
16371490
best_L_nodes.insert(Neighbor(id_scratch[m], normalized_distance + w_m * res_vec[m]));
16381491
}

0 commit comments

Comments
 (0)