diff --git a/include/abstract_index.h b/include/abstract_index.h index dcd5c853f..1bb94fa06 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -81,7 +81,8 @@ class AbstractIndex // IndexType is either uint32_t or uint64_t template std::pair search_with_filters(const DataType &query, const std::string &raw_label, - const size_t K, const uint32_t L, IndexType *indices, + const size_t K, const uint32_t L, const uint32_t maxLperSeller, + IndexType *indices, float *distances); // insert points with labels, labels should be present for filtered index @@ -122,7 +123,7 @@ class AbstractIndex virtual std::pair _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any& indices, float* distances = nullptr) = 0; virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, - const size_t K, const uint32_t L, std::any &indices, + const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices, float *distances) = 0; virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector &labels) = 0; virtual int _insert_point(const DataType &data_point, const TagType tag) = 0; diff --git a/include/index.h b/include/index.h index ad9a1f107..33059b43b 100644 --- a/include/index.h +++ b/include/index.h @@ -153,7 +153,7 @@ template clas // Filter support search template DISKANN_DLLEXPORT std::pair search_with_filters(const T *query, const LabelT &filter_label, - const size_t K, const uint32_t L, + const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices, float *distances); // Will fail if tag already in the index or if tag=0. @@ -218,7 +218,7 @@ template clas std::any &indices, float *distances = nullptr) override; virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label_raw, const size_t K, - const uint32_t L, std::any &indices, + const uint32_t L, const uint32_t maxLperSeller, std::any &indices, float *distances) override; virtual int _insert_point(const DataType &data_point, const TagType tag) override; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 0afa94550..a2c85e08e 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -44,11 +44,11 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, template std::pair AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label, - const size_t K, const uint32_t L, IndexType *indices, + const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices, float *distances) { auto any_indices = std::any(indices); - return _search_with_filters(query, raw_label, K, L, any_indices, distances); + return _search_with_filters(query, raw_label, K, L, maxLperSeller, any_indices, distances); } template @@ -173,11 +173,11 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search AbstractIndex::search_with_filters( - const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( - const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices, + const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( diff --git a/src/index.cpp b/src/index.cpp index aca7b0d3b..7a5d7130f 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2405,19 +2405,19 @@ std::pair Index::search(const T *query, con template std::pair Index::_search_with_filters(const DataType &query, const std::string &raw_label, const size_t K, - const uint32_t L, std::any &indices, + const uint32_t L, const uint32_t maxLperSeller, std::any &indices, float *distances) { auto converted_label = this->get_converted_label(raw_label); if (typeid(uint64_t *) == indices.type()) { auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, maxLperSeller, ptr, distances); } else if (typeid(uint32_t *) == indices.type()) { auto ptr = std::any_cast(indices); - return this->search_with_filters(std::any_cast(query), converted_label, K, L, ptr, distances); + return this->search_with_filters(std::any_cast(query), converted_label, K, L, maxLperSeller, ptr, distances); } else { @@ -2428,7 +2428,7 @@ std::pair Index::_search_with_filters(const template template std::pair Index::search_with_filters(const T *query, const LabelT &filter_label, - const size_t K, const uint32_t L, + const size_t K, const uint32_t L, const uint32_t maxLperSeller, IdType *indices, float *distances) { if (K > (uint64_t)L) @@ -2471,25 +2471,31 @@ std::pair Index::search_with_filters(const filter_vec.emplace_back(filter_label); _data_store->preprocess_query(query, scratch); - auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); + auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller); - auto best_L_nodes = scratch->best_l_nodes(); + NeighborPriorityQueueBase* best_L_nodes; + if (!_diverse_index) { + best_L_nodes = &(scratch->best_l_nodes()); + } + else { + best_L_nodes = &(scratch->best_diverse_nodes()); + } size_t pos = 0; - for (size_t i = 0; i < best_L_nodes.size(); ++i) + for (size_t i = 0; i < best_L_nodes->size(); ++i) { - if (best_L_nodes[i].id < _max_points) + if ((*best_L_nodes)[i].id < _max_points) { - indices[pos] = (IdType)best_L_nodes[i].id; + indices[pos] = (IdType)(*best_L_nodes)[i].id; if (distances != nullptr) { #ifdef EXEC_ENV_OLS // DLVS expects negative distances - distances[pos] = best_L_nodes[i].distance; + distances[pos] = (*best_L_nodes)[i].distance; #else - distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance - : best_L_nodes[i].distance; + distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * (*best_L_nodes)[i].distance + : (*best_L_nodes)[i].distance; #endif } pos++; @@ -3737,41 +3743,41 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search( @@ -3801,40 +3807,40 @@ template DISKANN_DLLEXPORT std::pair Index Index::search_with_filters< - uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, + uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< - uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices, + uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices, float *distances); } // namespace diskann