@@ -951,25 +951,63 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::brute_force_filters(const
951
951
952
952
template <typename T, typename TagT, typename LabelT>
953
953
inline float Index<T, TagT, LabelT>:: calculate_jaccard_similarity (const std::vector<LabelT> &set1, const std::vector<LabelT> &set2) {
954
- // std::cout << "calculate_jaccard_similarity called" << std::endl;
955
- std::unordered_set<LabelT> intersection, union_set;
956
-
957
- for (const auto &label : set1) {
958
- union_set.insert (label);
959
- }
960
- for (const auto &label : set2) {
961
- if (union_set.find (label) != union_set.end ()) {
962
- intersection.insert (label);
954
+ if (set1.empty ()) return 0 .0f ;
955
+
956
+ size_t intersection_count = 0 ;
957
+
958
+ // For small sets, linear scan is often faster due to cache locality
959
+ // Threshold based on your colleagues' discussion about cache vs complexity
960
+ constexpr size_t LINEAR_SCAN_THRESHOLD = 100 ;
961
+
962
+ if (set1.size () <= LINEAR_SCAN_THRESHOLD && set2.size () <= LINEAR_SCAN_THRESHOLD) {
963
+ // Linear scan approach - cache friendly for small sets
964
+ for (const auto &label : set1) {
965
+ if (std::find (set2.begin (), set2.end (), label) != set2.end ()) {
966
+ intersection_count++;
967
+ }
968
+ }
969
+ } else {
970
+ // Hash table approach for larger sets
971
+ const auto &smaller = set1.size () <= set2.size () ? set1 : set2;
972
+ const auto &larger = set1.size () <= set2.size () ? set2 : set1;
973
+
974
+ std::unordered_set<LabelT> lookup_set (larger.begin (), larger.end ());
975
+
976
+ for (const auto &label : smaller) {
977
+ if (lookup_set.count (label)) {
978
+ intersection_count++;
979
+ }
963
980
}
964
- union_set.insert (label);
965
981
}
982
+
983
+ return static_cast <float >(intersection_count) / static_cast <float >(set1.size ());
984
+ }
966
985
967
- if (union_set.empty ()) {
968
- return 0 .0f ; // Avoid division by zero
986
+ // Overloaded version for multi-filter query labels (vector<vector<LabelT>>)
987
+ // Returns the count of how many filter sets (clauses) have intersection with vector_labels
988
+ template <typename T, typename TagT, typename LabelT>
989
+ inline float Index<T, TagT, LabelT>:: calculate_jaccard_similarity (const std::vector<std::vector<LabelT>> &filter_sets, const std::vector<LabelT> &vector_labels) {
990
+ if (filter_sets.empty ()) return 0 .0f ;
991
+
992
+ size_t matching_clauses = 0 ;
993
+
994
+ // Count how many filter sets (clauses) have ANY intersection with vector_labels
995
+ for (const auto & filter_set : filter_sets) {
996
+ // Check if ANY filter in this clause matches ANY label in vector_labels
997
+ bool clause_satisfied = false ;
998
+ for (const auto & filter : filter_set) {
999
+ if (std::find (vector_labels.begin (), vector_labels.end (), filter) != vector_labels.end ()) {
1000
+ clause_satisfied = true ;
1001
+ break ; // Early exit - this clause is satisfied
1002
+ }
1003
+ }
1004
+ if (clause_satisfied) {
1005
+ matching_clauses++;
1006
+ }
969
1007
}
970
-
971
- // return static_cast<float>(intersection.size()) / static_cast<float>(union_set.size());
972
- return static_cast <float >(intersection. size ()) / static_cast <float >(set1 .size ());
1008
+
1009
+ // Return fraction of clauses that match
1010
+ return static_cast <float >(matching_clauses) / static_cast <float >(filter_sets .size ());
973
1011
}
974
1012
975
1013
@@ -1067,7 +1105,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1067
1105
// if (res > 0) {
1068
1106
// res = 1;
1069
1107
// }
1070
- res = 1 - calculate_jaccard_similarity (filter_labels[ 0 ] , _location_to_labels[id]);
1108
+ res = 1 - calculate_jaccard_similarity (filter_labels, _location_to_labels[id]);
1071
1109
if (print_qstats)
1072
1110
{
1073
1111
std::ofstream out (" query_stats.txt" , std::ios_base::app);
@@ -1095,7 +1133,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1095
1133
continue ;
1096
1134
}
1097
1135
else {
1098
- res = 1 - calculate_jaccard_similarity (filter_labels[ 0 ] , _location_to_labels[id]);
1136
+ res = 1 - calculate_jaccard_similarity (filter_labels, _location_to_labels[id]);
1099
1137
if (print_qstats)
1100
1138
{
1101
1139
std::ofstream out (" query_stats.txt" , std::ios_base::app);
@@ -1298,7 +1336,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1298
1336
// penalty = res * penalty_scale;
1299
1337
// i
1300
1338
1301
- res = 1 - calculate_jaccard_similarity (filter_labels[ 0 ] , _location_to_labels[id]);
1339
+ res = 1 - calculate_jaccard_similarity (filter_labels, _location_to_labels[id]);
1302
1340
1303
1341
1304
1342
if (print_qstats)
@@ -1319,7 +1357,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
1319
1357
if (detect_common_filters (id, search_invocation, filter_labels) < min_inter_size)
1320
1358
continue ;
1321
1359
else {
1322
- res = 1 - calculate_jaccard_similarity (filter_labels[ 0 ] , _location_to_labels[id]);
1360
+ res = 1 - calculate_jaccard_similarity (filter_labels, _location_to_labels[id]);
1323
1361
}
1324
1362
}
1325
1363
}
0 commit comments