Skip to content

Commit 7cf5574

Browse files
authored
Jegao/diverse with filtered (#664)
* add diverse + filtered interface
* fix issue
1 parent 9da186a commit 7cf5574

File tree

4 files changed

+51
-44
lines changed

4 files changed

+51
-44
lines changed

include/abstract_index.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ class AbstractIndex
8181
// IndexType is either uint32_t or uint64_t
8282
template <typename IndexType>
8383
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
84-
const size_t K, const uint32_t L, IndexType *indices,
84+
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
85+
IndexType *indices,
8586
float *distances);
8687

8788
// insert points with labels, labels should be present for filtered index
@@ -122,7 +123,7 @@ class AbstractIndex
122123
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
123124
std::any& indices, float* distances = nullptr) = 0;
124125
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
125-
const size_t K, const uint32_t L, std::any &indices,
126+
const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
126127
float *distances) = 0;
127128
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) = 0;
128129
virtual int _insert_point(const DataType &data_point, const TagType tag) = 0;

include/index.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
153153
// Filter support search
154154
template <typename IndexType>
155155
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
156-
const size_t K, const uint32_t L,
156+
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
157157
IndexType *indices, float *distances);
158158

159159
// Will fail if tag already in the index or if tag=0.
@@ -218,7 +218,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
218218
std::any &indices, float *distances = nullptr) override;
219219
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
220220
const std::string &filter_label_raw, const size_t K,
221-
const uint32_t L, std::any &indices,
221+
const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
222222
float *distances) override;
223223

224224
virtual int _insert_point(const DataType &data_point, const TagType tag) override;

src/abstract_index.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K,
4444

4545
template <typename IndexType>
4646
std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label,
47-
const size_t K, const uint32_t L, IndexType *indices,
47+
const size_t K, const uint32_t L, const uint32_t maxLperSeller, IndexType *indices,
4848
float *distances)
4949
{
5050
auto any_indices = std::any(indices);
51-
return _search_with_filters(query, raw_label, K, L, any_indices, distances);
51+
return _search_with_filters(query, raw_label, K, L, maxLperSeller, any_indices, distances);
5252
}
5353

5454
template <typename data_type>
@@ -173,11 +173,11 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<i
173173
const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
174174

175175
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
176-
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices,
176+
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
177177
float *distances);
178178

179179
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
180-
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
180+
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
181181
float *distances);
182182

183183
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::diverse_search<float, uint32_t>(

src/index.cpp

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,19 +2405,19 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
24052405
template <typename T, typename TagT, typename LabelT>
24062406
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const DataType &query,
24072407
const std::string &raw_label, const size_t K,
2408-
const uint32_t L, std::any &indices,
2408+
const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
24092409
float *distances)
24102410
{
24112411
auto converted_label = this->get_converted_label(raw_label);
24122412
if (typeid(uint64_t *) == indices.type())
24132413
{
24142414
auto ptr = std::any_cast<uint64_t *>(indices);
2415-
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, ptr, distances);
2415+
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, maxLperSeller, ptr, distances);
24162416
}
24172417
else if (typeid(uint32_t *) == indices.type())
24182418
{
24192419
auto ptr = std::any_cast<uint32_t *>(indices);
2420-
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, ptr, distances);
2420+
return this->search_with_filters(std::any_cast<const T *>(query), converted_label, K, L, maxLperSeller, ptr, distances);
24212421
}
24222422
else
24232423
{
@@ -2428,7 +2428,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const
24282428
template <typename T, typename TagT, typename LabelT>
24292429
template <typename IdType>
24302430
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const T *query, const LabelT &filter_label,
2431-
const size_t K, const uint32_t L,
2431+
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
24322432
IdType *indices, float *distances)
24332433
{
24342434
if (K > (uint64_t)L)
@@ -2471,25 +2471,31 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
24712471
filter_vec.emplace_back(filter_label);
24722472

24732473
_data_store->preprocess_query(query, scratch);
2474-
auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true);
2474+
auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller);
24752475

2476-
auto best_L_nodes = scratch->best_l_nodes();
2476+
NeighborPriorityQueueBase* best_L_nodes;
2477+
if (!_diverse_index) {
2478+
best_L_nodes = &(scratch->best_l_nodes());
2479+
}
2480+
else {
2481+
best_L_nodes = &(scratch->best_diverse_nodes());
2482+
}
24772483

24782484
size_t pos = 0;
2479-
for (size_t i = 0; i < best_L_nodes.size(); ++i)
2485+
for (size_t i = 0; i < best_L_nodes->size(); ++i)
24802486
{
2481-
if (best_L_nodes[i].id < _max_points)
2487+
if ((*best_L_nodes)[i].id < _max_points)
24822488
{
2483-
indices[pos] = (IdType)best_L_nodes[i].id;
2489+
indices[pos] = (IdType)(*best_L_nodes)[i].id;
24842490

24852491
if (distances != nullptr)
24862492
{
24872493
#ifdef EXEC_ENV_OLS
24882494
// DLVS expects negative distances
2489-
distances[pos] = best_L_nodes[i].distance;
2495+
distances[pos] = (*best_L_nodes)[i].distance;
24902496
#else
2491-
distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance
2492-
: best_L_nodes[i].distance;
2497+
distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * (*best_L_nodes)[i].distance
2498+
: (*best_L_nodes)[i].distance;
24932499
#endif
24942500
}
24952501
pos++;
@@ -3737,41 +3743,41 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
37373743
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller);
37383744

37393745
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
3740-
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3746+
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
37413747
float *distances);
37423748
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint32_t>::search_with_filters<
3743-
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3749+
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
37443750
float *distances);
37453751
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
3746-
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3752+
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
37473753
float *distances);
37483754
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint32_t>::search_with_filters<
3749-
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3755+
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
37503756
float *distances);
37513757
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
3752-
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3758+
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
37533759
float *distances);
37543760
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint32_t>::search_with_filters<
3755-
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3761+
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
37563762
float *distances);
37573763
// TagT==uint32_t
37583764
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
3759-
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3765+
uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
37603766
float *distances);
37613767
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint32_t>::search_with_filters<
3762-
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3768+
uint32_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
37633769
float *distances);
37643770
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
3765-
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3771+
uint64_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
37663772
float *distances);
37673773
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint32_t>::search_with_filters<
3768-
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3774+
uint32_t>(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
37693775
float *distances);
37703776
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
3771-
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3777+
uint64_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
37723778
float *distances);
37733779
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint32_t>::search_with_filters<
3774-
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3780+
uint32_t>(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
37753781
float *distances);
37763782

37773783
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search<uint64_t>(
@@ -3801,40 +3807,40 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
38013807
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller);
38023808

38033809
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
3804-
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3810+
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
38053811
float *distances);
38063812
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint64_t, uint16_t>::search_with_filters<
3807-
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3813+
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
38083814
float *distances);
38093815
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
3810-
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3816+
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
38113817
float *distances);
38123818
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint64_t, uint16_t>::search_with_filters<
3813-
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3819+
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
38143820
float *distances);
38153821
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
3816-
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3822+
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
38173823
float *distances);
38183824
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint64_t, uint16_t>::search_with_filters<
3819-
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3825+
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
38203826
float *distances);
38213827
// TagT==uint32_t
38223828
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
3823-
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3829+
uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
38243830
float *distances);
38253831
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<float, uint32_t, uint16_t>::search_with_filters<
3826-
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3832+
uint32_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
38273833
float *distances);
38283834
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
3829-
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3835+
uint64_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
38303836
float *distances);
38313837
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<uint8_t, uint32_t, uint16_t>::search_with_filters<
3832-
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3838+
uint32_t>(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
38333839
float *distances);
38343840
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
3835-
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3841+
uint64_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
38363842
float *distances);
38373843
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t, uint16_t>::search_with_filters<
3838-
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3844+
uint32_t>(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
38393845
float *distances);
38403846
} // namespace diskann

0 commit comments

Comments
 (0)