Skip to content

Commit 389e404

Browse files
t-sutradhara_microsoftsdananya
authored and committed
modified jaccard similarity
1 parent 175e97c commit 389e404

File tree

4 files changed

+120
-38
lines changed

4 files changed

+120
-38
lines changed

apps/search_memory_index.cpp

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
142142
<< std::setw(20) << "Brute Recall"
143143
<< std::setw(22) << "Graph avg cmps"
144144
<< std::setw(22) << "Graph Latency(mus)"
145-
<< std::setw(20) << "Graph Recall"
145+
<< std::setw(20) << "Graph Recall"
146+
<< std::setw(18) << "Filter Eval(mus)"
147+
<< std::setw(18) << "Penalty Det(mus)"
148+
<< std::setw(18) << "Core Algo(mus)"
146149
<< std::endl;
147150

148151
table_width += 4 + 4 + 8 + 18 + 20 + 20 + 20 + 20 + 10 + 22 + 20 + 22 + 20 + 22 + 22;
@@ -408,26 +411,49 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
408411
}
409412
else
410413
{
411-
std::cout << std::setw(4) << L << std::setw(4) << recall_at << std::setw(8) << displayed_qps << std::setw(15) << avg_cmps
412-
<< std::setw(20) << (float)mean_latency << std::setw(15)
413-
<< std::setw(20) << (float)latency_999 << std::setw(15)
414-
<< std::setw(20) << (float)latency_99 << std::setw(15)
415-
<< std::setw(20) << (float)latency_95 << std::setw(15)
416-
<< (float)recalls[0] << std::setw(20)
417-
<< (float)(brute_dist_cmp[test_id] * 1.0) / (num_brutes * 1.0) << std::setw(22)
418-
<< (float)(brute_lat[test_id] * 1.0) / (num_brutes * 1.0) << std::setw(20)
419-
<< (float)(brute_recalls[test_id] * 100.0) / (num_brutes * recall_at * 1.0) << std::setw(20)
420-
<< (float)(graph_lat[test_id] * 1.0) / (num_graphs * 1.0) << std::setw(20)
421-
<< (float)(graph_recalls[test_id] * 100.0) / (num_graphs * recall_at * 1.0) << " " << (1000000*time_to_detect_penalty) / query_num << "\t" << (1000000*time_to_get_valid) / query_num
422-
// << std::setw(20) << (float)(brute_lat[test_id]*1.0) << std::setw(20) <<
423-
// (float)(brute_recalls[test_id]*100.0)
424-
// << std::setw(20) << (float)(graph_lat[test_id]*1.0) << std::setw(20) <<
425-
// (float)(graph_recalls[test_id]*100.0)
414+
// Calculate timing breakdowns (convert to microseconds)
415+
float filter_eval_time_us = (float)(time_to_get_valid * 1000000.0) / (float)query_num;
416+
float penalty_detection_time_us = (float)(time_to_detect_penalty * 1000000.0) / (float)query_num;
417+
float core_algo_time_us = mean_latency - filter_eval_time_us - penalty_detection_time_us;
418+
419+
std::cout << std::setw(4) << L << std::setw(4) << recall_at << std::setw(8) << displayed_qps << std::setw(18) << avg_cmps
420+
<< std::setw(20) << (float)mean_latency
421+
<< std::setw(20) << (float)latency_999
422+
<< std::setw(20) << (float)latency_99
423+
<< std::setw(20) << (float)latency_95
424+
<< std::setw(10) << (float)recalls[0]
425+
<< std::setw(22) << (float)(brute_dist_cmp[test_id] * 1.0) / (num_brutes * 1.0)
426+
<< std::setw(22) << (float)(brute_lat[test_id] * 1.0) / (num_brutes * 1.0)
427+
<< std::setw(20) << (float)(brute_recalls[test_id] * 100.0) / (num_brutes * recall_at * 1.0)
428+
<< std::setw(22) << (float)(graph_dist_cmp[test_id] * 1.0) / (num_graphs * 1.0)
429+
<< std::setw(22) << (float)(graph_lat[test_id] * 1.0) / (num_graphs * 1.0)
430+
<< std::setw(20) << (float)(graph_recalls[test_id] * 100.0) / (num_graphs * recall_at * 1.0)
431+
<< std::setw(18) << filter_eval_time_us
432+
<< std::setw(18) << penalty_detection_time_us
433+
<< std::setw(18) << core_algo_time_us
426434
<< std::endl;
427435
}
428436
}
429437
std::cout << "num_graphs " << num_graphs << std::endl;
430438
std::cout << "num_brutes " << num_brutes << std::endl;
439+
440+
// Print detailed timing breakdown summary
441+
if (filtered_search) {
442+
std::cout << "\n=== TIMING BREAKDOWN ANALYSIS ===" << std::endl;
443+
std::cout << "Total queries: " << query_num << std::endl;
444+
std::cout << "Filter evaluation time: " << (time_to_get_valid * 1000000.0) / query_num << " μs/query" << std::endl;
445+
std::cout << "Penalty detection time: " << (time_to_detect_penalty * 1000000.0) / query_num << " μs/query" << std::endl;
446+
std::cout << "Filter intersection time: " << (time_to_intersect * 1000000.0) / query_num << " μs/query" << std::endl;
447+
std::cout << "Filter check & compare time: " << (time_to_filter_check_and_compare * 1000000.0) / query_num << " μs/query" << std::endl;
448+
449+
double total_filter_overhead = (time_to_get_valid + time_to_detect_penalty + time_to_intersect + time_to_filter_check_and_compare) * 1000000.0 / query_num;
450+
std::cout << "Total filter overhead: " << total_filter_overhead << " μs/query" << std::endl;
451+
452+
std::cout << "Breakdown percentage:" << std::endl;
453+
std::cout << " Graph searches: " << (100.0 * num_graphs) / query_num << "%" << std::endl;
454+
std::cout << " Brute force searches: " << (100.0 * num_brutes) / query_num << "%" << std::endl;
455+
std::cout << "=================================" << std::endl;
456+
}
431457

432458
std::cout << "Done searching. Now saving results " << std::endl;
433459
uint64_t test_id = 0;

include/index.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
121121
const std::vector<std::vector<LabelT>> &incoming_labels);
122122

123123
DISKANN_DLLEXPORT inline float calculate_jaccard_similarity(const std::vector<LabelT> &set1, const std::vector<LabelT> &set2);
124+
125+
// Overloaded version for multi-filter query labels (vector<vector<LabelT>>)
126+
DISKANN_DLLEXPORT inline float calculate_jaccard_similarity(const std::vector<std::vector<LabelT>> &filter_sets, const std::vector<LabelT> &vector_labels);
124127

125128
// Batch build from a file. Optionally pass tags vector.
126129
DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load,

scripts/ml_ilp/ilp.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def main():
130130
parser.add_argument('--method', choices=['ratio', 'lp', 'pulp'], default='ratio')
131131
parser.add_argument('--eps', type=float, default=1e-4)
132132
parser.add_argument('--plot', action='store_true')
133+
parser.add_argument('--norm_factors', help='Normalization factors file (scale, shift)', default=None)
133134
args = parser.parse_args()
134135

135136
# Read the ground truth file
@@ -164,9 +165,23 @@ def main():
164165
print(f"Distances shape: {distances.shape}")
165166
print(f"Matches shape: {matches.shape}")
166167

167-
print(f"Distances: {distances[0][:5]}")
168-
# distances_scaled = distances / distances.max()
169-
# print(f"Scaled distances: {distances_scaled[0][:5]}")
168+
print(f"Original distances: {distances[0][:5]}")
169+
170+
# Apply normalization if provided
171+
if args.norm_factors:
172+
with open(args.norm_factors, 'r') as f:
173+
line = f.readline().strip()
174+
scale, shift = map(float, line.split())
175+
print(f"Applying normalization: scale={scale}, shift={shift}")
176+
distances_normalized = (distances + shift) * scale
177+
print(f"Normalized distances: {distances_normalized[0][:5]}")
178+
distances = distances_normalized
179+
else:
180+
# Fallback: simple max normalization
181+
distances_max = distances.max()
182+
print(f"Max distance: {distances_max}")
183+
distances = distances / distances_max
184+
print(f"Max-normalized distances: {distances[0][:5]}")
170185

171186
if args.method == 'ratio':
172187
w_d, w_m, total_pairs, _ = direct_ratio_method(distances, filter_matches, args.eps)

src/index.cpp

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -951,25 +951,63 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::brute_force_filters(const
951951

952952
template <typename T, typename TagT, typename LabelT>
953953
inline float Index<T, TagT, LabelT>:: calculate_jaccard_similarity(const std::vector<LabelT> &set1, const std::vector<LabelT> &set2) {
954-
// std::cout << "calculate_jaccard_similarity called" << std::endl;
955-
std::unordered_set<LabelT> intersection, union_set;
956-
957-
for (const auto &label : set1) {
958-
union_set.insert(label);
959-
}
960-
for (const auto &label : set2) {
961-
if (union_set.find(label) != union_set.end()) {
962-
intersection.insert(label);
954+
if (set1.empty()) return 0.0f;
955+
956+
size_t intersection_count = 0;
957+
958+
// For small sets, linear scan is often faster due to cache locality
959+
// Threshold is a heuristic cutoff trading cache locality against asymptotic complexity
960+
constexpr size_t LINEAR_SCAN_THRESHOLD = 100;
961+
962+
if (set1.size() <= LINEAR_SCAN_THRESHOLD && set2.size() <= LINEAR_SCAN_THRESHOLD) {
963+
// Linear scan approach - cache friendly for small sets
964+
for (const auto &label : set1) {
965+
if (std::find(set2.begin(), set2.end(), label) != set2.end()) {
966+
intersection_count++;
967+
}
968+
}
969+
} else {
970+
// Hash table approach for larger sets
971+
const auto &smaller = set1.size() <= set2.size() ? set1 : set2;
972+
const auto &larger = set1.size() <= set2.size() ? set2 : set1;
973+
974+
std::unordered_set<LabelT> lookup_set(larger.begin(), larger.end());
975+
976+
for (const auto &label : smaller) {
977+
if (lookup_set.count(label)) {
978+
intersection_count++;
979+
}
963980
}
964-
union_set.insert(label);
965981
}
982+
983+
return static_cast<float>(intersection_count) / static_cast<float>(set1.size());
984+
}
966985

967-
if (union_set.empty()) {
968-
return 0.0f; // Avoid division by zero
986+
// Overloaded version for multi-filter query labels (vector<vector<LabelT>>)
987+
// Returns the fraction of filter sets (clauses) that share at least one label with vector_labels
988+
template <typename T, typename TagT, typename LabelT>
989+
inline float Index<T, TagT, LabelT>:: calculate_jaccard_similarity(const std::vector<std::vector<LabelT>> &filter_sets, const std::vector<LabelT> &vector_labels) {
990+
if (filter_sets.empty()) return 0.0f;
991+
992+
size_t matching_clauses = 0;
993+
994+
// Count how many filter sets (clauses) have ANY intersection with vector_labels
995+
for (const auto& filter_set : filter_sets) {
996+
// Check if ANY filter in this clause matches ANY label in vector_labels
997+
bool clause_satisfied = false;
998+
for (const auto& filter : filter_set) {
999+
if (std::find(vector_labels.begin(), vector_labels.end(), filter) != vector_labels.end()) {
1000+
clause_satisfied = true;
1001+
break; // Early exit - this clause is satisfied
1002+
}
1003+
}
1004+
if (clause_satisfied) {
1005+
matching_clauses++;
1006+
}
9691007
}
970-
971-
// return static_cast<float>(intersection.size()) / static_cast<float>(union_set.size());
972-
return static_cast<float>(intersection.size()) / static_cast<float>(set1.size());
1008+
1009+
// Return fraction of clauses that match
1010+
return static_cast<float>(matching_clauses) / static_cast<float>(filter_sets.size());
9731011
}
9741012

9751013

@@ -1067,7 +1105,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
10671105
// if (res > 0) {
10681106
// res = 1;
10691107
// }
1070-
res = 1 - calculate_jaccard_similarity(filter_labels[0], _location_to_labels[id]);
1108+
res = 1 - calculate_jaccard_similarity(filter_labels, _location_to_labels[id]);
10711109
if (print_qstats)
10721110
{
10731111
std::ofstream out("query_stats.txt", std::ios_base::app);
@@ -1095,7 +1133,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
10951133
continue;
10961134
}
10971135
else {
1098-
res = 1 - calculate_jaccard_similarity(filter_labels[0], _location_to_labels[id]);
1136+
res = 1 - calculate_jaccard_similarity(filter_labels, _location_to_labels[id]);
10991137
if (print_qstats)
11001138
{
11011139
std::ofstream out("query_stats.txt", std::ios_base::app);
@@ -1298,7 +1336,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
12981336
// penalty = res * penalty_scale;
12991337
// i
13001338

1301-
res = 1 - calculate_jaccard_similarity(filter_labels[0], _location_to_labels[id]);
1339+
res = 1 - calculate_jaccard_similarity(filter_labels, _location_to_labels[id]);
13021340

13031341

13041342
if (print_qstats)
@@ -1319,7 +1357,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
13191357
if (detect_common_filters(id, search_invocation, filter_labels) < min_inter_size)
13201358
continue;
13211359
else {
1322-
res = 1 - calculate_jaccard_similarity(filter_labels[0], _location_to_labels[id]);
1360+
res = 1 - calculate_jaccard_similarity(filter_labels, _location_to_labels[id]);
13231361
}
13241362
}
13251363
}

0 commit comments

Comments
 (0)