Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ class AbstractIndex
// IndexType is either uint32_t or uint64_t
template <typename IndexType>
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
IndexType *indices,
float *distances);

// insert points with labels, labels should be present for filtered index
Expand Down Expand Up @@ -122,7 +123,7 @@ class AbstractIndex
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
std::any& indices, float* distances = nullptr) = 0;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
const size_t K, const uint32_t L, std::any &indices,
const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
float *distances) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;
Expand Down
4 changes: 2 additions & 2 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// Filter support search
template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
const size_t K, const uint32_t L,
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
IndexType *indices, float *distances);

// Will fail if tag already in the index or if tag=0.
Expand Down Expand Up @@ -218,7 +218,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
std::any &indices, float *distances = nullptr) override;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
const std::string &filter_label_raw, const size_t K,
const uint32_t L, std::any &indices,
const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
float *distances) override;

virtual int _insert_point(const DataType &data_point, const TagType tag) override;
Expand Down
8 changes: 4 additions & 4 deletions src/abstract_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K,

template <typename IndexType>
std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices,
float *distances)
{
auto any_indices = std::any(indices);
return _search_with_filters(query, raw_label, K, L, any_indices, distances);
return _search_with_filters(query, raw_label, K, L, maxLperSeller, any_indices, distances);
}

template <typename data_type>
Expand Down Expand Up @@ -173,11 +173,11 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<i
const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices,
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::diverse_search<float, uint32_t>(
Expand Down
78 changes: 42 additions & 36 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2405,19 +2405,19 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
template <typename T, typename TagT, typename LabelT>
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const DataType &query,
const std::string &raw_label, const size_t K,
const uint32_t L, std::any &indices,
const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
float *distances)
{
auto converted_label = this->get_converted_label(raw_label);
if (typeid(uint64_t *) == indices.type())
{
auto ptr = std::any_cast<uint64_t *>(indices);
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, ptr, distances);
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, maxLperSeller, ptr, distances);
}
else if (typeid(uint32_t *) == indices.type())
{
auto ptr = std::any_cast<uint32_t *>(indices);
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, ptr, distances);
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, maxLperSeller, ptr, distances);
}
else
{
Expand All @@ -2428,7 +2428,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const
template <typename T, typename TagT, typename LabelT>
template <typename IdType>
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const T *query, const LabelT &filter_label,
const size_t K, const uint32_t L,
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
IdType *indices, float *distances)
{
if (K > (uint64_t)L)
Expand Down Expand Up @@ -2471,25 +2471,31 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
filter_vec.emplace_back(filter_label);

_data_store->preprocess_query(query, scratch);
auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true);
auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller);

auto best_L_nodes = scratch->best_l_nodes();
NeighborPriorityQueueBase* best_L_nodes;
if (!_diverse_index) {
best_L_nodes = &(scratch->best_l_nodes());
}
else {
best_L_nodes = &(scratch->best_diverse_nodes());
}

size_t pos = 0;
for (size_t i = 0; i < best_L_nodes.size(); ++i)
for (size_t i = 0; i < best_L_nodes->size(); ++i)
{
if (best_L_nodes[i].id < _max_points)
if ((*best_L_nodes)[i].id < _max_points)
{
indices[pos] = (IdType)best_L_nodes[i].id;
indices[pos] = (IdType)(*best_L_nodes)[i].id;

if (distances != nullptr)
{
#ifdef EXEC_ENV_OLS
// DLVS expects negative distances
distances[pos] = best_L_nodes[i].distance;
distances[pos] = (*best_L_nodes)[i].distance;
#else
distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance
: best_L_nodes[i].distance;
distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * (*best_L_nodes)[i].distance
: (*best_L_nodes)[i].distance;
#endif
}
pos++;
Expand Down Expand Up @@ -3737,41 +3743,41 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
// TagT==uint32_t
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search<uint64_t>(
Expand Down Expand Up @@ -3801,40 +3807,40 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
// TagT==uint32_t
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
float *distances);
} // namespace diskann
Loading