@@ -2405,19 +2405,19 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
2405
2405
template <typename T, typename TagT, typename LabelT>
2406
2406
std::pair<uint32_t , uint32_t > Index<T, TagT, LabelT>::_search_with_filters (const DataType &query,
2407
2407
const std::string &raw_label, const size_t K,
2408
- const uint32_t L, std::any &indices,
2408
+ const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
2409
2409
float *distances)
2410
2410
{
2411
2411
auto converted_label = this ->get_converted_label (raw_label);
2412
2412
if (typeid (uint64_t *) == indices.type ())
2413
2413
{
2414
2414
auto ptr = std::any_cast<uint64_t *>(indices);
2415
- return this ->search_with_filters (std::any_cast<const T *>(query), converted_label, K, L, ptr, distances);
2415
+ return this ->search_with_filters (std::any_cast<const T *>(query), converted_label, K, L, maxLperSeller, ptr, distances);
2416
2416
}
2417
2417
else if (typeid (uint32_t *) == indices.type ())
2418
2418
{
2419
2419
auto ptr = std::any_cast<uint32_t *>(indices);
2420
- return this ->search_with_filters (std::any_cast<const T *>(query), converted_label, K, L, ptr, distances);
2420
+ return this ->search_with_filters (std::any_cast<const T *>(query), converted_label, K, L, maxLperSeller, ptr, distances);
2421
2421
}
2422
2422
else
2423
2423
{
@@ -2428,7 +2428,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::_search_with_filters(const
2428
2428
template <typename T, typename TagT, typename LabelT>
2429
2429
template <typename IdType>
2430
2430
std::pair<uint32_t , uint32_t > Index<T, TagT, LabelT>::search_with_filters (const T *query, const LabelT &filter_label,
2431
- const size_t K, const uint32_t L,
2431
+ const size_t K, const uint32_t L, const uint32_t maxLperSeller,
2432
2432
IdType *indices, float *distances)
2433
2433
{
2434
2434
if (K > (uint64_t )L)
@@ -2471,25 +2471,31 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
2471
2471
filter_vec.emplace_back (filter_label);
2472
2472
2473
2473
_data_store->preprocess_query (query, scratch);
2474
- auto retval = iterate_to_fixed_point (scratch, L, init_ids, true , filter_vec, true );
2474
+ auto retval = iterate_to_fixed_point (scratch, L, init_ids, true , filter_vec, true , maxLperSeller );
2475
2475
2476
- auto best_L_nodes = scratch->best_l_nodes ();
2476
+ NeighborPriorityQueueBase* best_L_nodes;
2477
+ if (!_diverse_index) {
2478
+ best_L_nodes = &(scratch->best_l_nodes ());
2479
+ }
2480
+ else {
2481
+ best_L_nodes = &(scratch->best_diverse_nodes ());
2482
+ }
2477
2483
2478
2484
size_t pos = 0 ;
2479
- for (size_t i = 0 ; i < best_L_nodes. size (); ++i)
2485
+ for (size_t i = 0 ; i < best_L_nodes-> size (); ++i)
2480
2486
{
2481
- if (best_L_nodes[i].id < _max_points)
2487
+ if ((* best_L_nodes) [i].id < _max_points)
2482
2488
{
2483
- indices[pos] = (IdType)best_L_nodes[i].id ;
2489
+ indices[pos] = (IdType)(* best_L_nodes) [i].id ;
2484
2490
2485
2491
if (distances != nullptr )
2486
2492
{
2487
2493
#ifdef EXEC_ENV_OLS
2488
2494
// DLVS expects negative distances
2489
- distances[pos] = best_L_nodes[i].distance ;
2495
+ distances[pos] = (* best_L_nodes) [i].distance ;
2490
2496
#else
2491
- distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance
2492
- : best_L_nodes[i].distance ;
2497
+ distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * (* best_L_nodes) [i].distance
2498
+ : (* best_L_nodes) [i].distance ;
2493
2499
#endif
2494
2500
}
2495
2501
pos++;
@@ -3737,41 +3743,41 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
3737
3743
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller);
3738
3744
3739
3745
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint64_t , uint32_t >::search_with_filters<
3740
- uint64_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3746
+ uint64_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3741
3747
float *distances);
3742
3748
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint64_t , uint32_t >::search_with_filters<
3743
- uint32_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3749
+ uint32_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3744
3750
float *distances);
3745
3751
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint64_t , uint32_t >::search_with_filters<
3746
- uint64_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3752
+ uint64_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3747
3753
float *distances);
3748
3754
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint64_t , uint32_t >::search_with_filters<
3749
- uint32_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3755
+ uint32_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3750
3756
float *distances);
3751
3757
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint64_t , uint32_t >::search_with_filters<
3752
- uint64_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3758
+ uint64_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3753
3759
float *distances);
3754
3760
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint64_t , uint32_t >::search_with_filters<
3755
- uint32_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3761
+ uint32_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3756
3762
float *distances);
3757
3763
// TagT==uint32_t
3758
3764
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint32_t , uint32_t >::search_with_filters<
3759
- uint64_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3765
+ uint64_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3760
3766
float *distances);
3761
3767
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint32_t , uint32_t >::search_with_filters<
3762
- uint32_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3768
+ uint32_t >(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3763
3769
float *distances);
3764
3770
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint32_t , uint32_t >::search_with_filters<
3765
- uint64_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3771
+ uint64_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3766
3772
float *distances);
3767
3773
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint32_t , uint32_t >::search_with_filters<
3768
- uint32_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3774
+ uint32_t >(const uint8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3769
3775
float *distances);
3770
3776
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint32_t , uint32_t >::search_with_filters<
3771
- uint64_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3777
+ uint64_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3772
3778
float *distances);
3773
3779
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint32_t , uint32_t >::search_with_filters<
3774
- uint32_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3780
+ uint32_t >(const int8_t *query, const uint32_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3775
3781
float *distances);
3776
3782
3777
3783
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint64_t , uint16_t >::search<uint64_t >(
@@ -3801,40 +3807,40 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> Index<int8_t, uint32_t,
3801
3807
const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller);
3802
3808
3803
3809
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint64_t , uint16_t >::search_with_filters<
3804
- uint64_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3810
+ uint64_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3805
3811
float *distances);
3806
3812
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint64_t , uint16_t >::search_with_filters<
3807
- uint32_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3813
+ uint32_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3808
3814
float *distances);
3809
3815
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint64_t , uint16_t >::search_with_filters<
3810
- uint64_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3816
+ uint64_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3811
3817
float *distances);
3812
3818
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint64_t , uint16_t >::search_with_filters<
3813
- uint32_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3819
+ uint32_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3814
3820
float *distances);
3815
3821
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint64_t , uint16_t >::search_with_filters<
3816
- uint64_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3822
+ uint64_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3817
3823
float *distances);
3818
3824
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint64_t , uint16_t >::search_with_filters<
3819
- uint32_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3825
+ uint32_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3820
3826
float *distances);
3821
3827
// TagT==uint32_t
3822
3828
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint32_t , uint16_t >::search_with_filters<
3823
- uint64_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3829
+ uint64_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3824
3830
float *distances);
3825
3831
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<float , uint32_t , uint16_t >::search_with_filters<
3826
- uint32_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3832
+ uint32_t >(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3827
3833
float *distances);
3828
3834
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint32_t , uint16_t >::search_with_filters<
3829
- uint64_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3835
+ uint64_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3830
3836
float *distances);
3831
3837
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<uint8_t , uint32_t , uint16_t >::search_with_filters<
3832
- uint32_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3838
+ uint32_t >(const uint8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3833
3839
float *distances);
3834
3840
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint32_t , uint16_t >::search_with_filters<
3835
- uint64_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices,
3841
+ uint64_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint64_t *indices,
3836
3842
float *distances);
3837
3843
template DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > Index<int8_t , uint32_t , uint16_t >::search_with_filters<
3838
- uint32_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint32_t *indices,
3844
+ uint32_t >(const int8_t *query, const uint16_t &filter_label, const size_t K, const uint32_t L, const uint32_t maxLperSeller, uint32_t *indices,
3839
3845
float *distances);
3840
3846
} // namespace diskann
0 commit comments