Skip to content

Commit a52995f

Browse files
committed
Merge branch 'jegao/LabelHotFix' of https://github.com/microsoft/DiskANN into jegao/FormatMemoryIndex
2 parents 9011711 + 4e99954 commit a52995f

File tree

7 files changed: +174 -38 lines changed

include/disk_utils.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,9 @@ DISKANN_DLLEXPORT int build_disk_index(
9999
const std::string &label_file = std::string(""), // default is empty string for no label_file
100100
const std::string &universal_label = "", const uint32_t filter_threshold = 0,
101101
const uint32_t Lf = 0,
102-
const char* reorderDataFilePath = nullptr); // default is empty string for no universal label
102+
const char* reorderDataFilePath = nullptr,
103+
const char* sellerFilePath = nullptr,
104+
uint32_t num_diverse_build = 1); // default is empty string for no universal label
103105

104106
template <typename T>
105107
DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file,

include/neighbor.h

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,8 @@ class NeighborVector : public NeighborVectorBase
127127
class NeighborExtendColorVector : public NeighborVectorBase
128128
{
129129
public:
130-
const static uint32_t s_vector_size_limited = 1000000;
130+
const static uint32_t s_vector_size_limited = 50 * 64;
131+
const static uint32_t s_map_size_reserve = 100;
131132

132133
NeighborExtendColorVector(const std::vector<uint32_t>& location_to_seller)
133134
: _location_to_seller(location_to_seller)
@@ -148,7 +149,7 @@ class NeighborExtendColorVector : public NeighborVectorBase
148149
else
149150
{
150151
auto color_to_info = std::make_unique<ColorInfoMap>();
151-
color_to_info->reserve(uniqueSellerCount);
152+
color_to_info->reserve(s_map_size_reserve);
152153
_color_to_info = std::move(color_to_info);
153154
}
154155
}
@@ -315,7 +316,7 @@ class NeighborExtendColorVector : public NeighborVectorBase
315316
{
316317
std::cout << "unique seller " << uniqueSellerCount << "ColorInfoMap created" << std::endl;
317318
auto color_to_info = std::make_unique<ColorInfoMap>();
318-
color_to_info->reserve(uniqueSellerCount);
319+
color_to_info->reserve(s_map_size_reserve);
319320
_color_to_info = std::move(color_to_info);
320321
}
321322
}

include/pq_flash_index.h

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
4646
const char* pivots_filepath, const char* compressed_filepath,
4747
const char* labels_filepath, const char* labels_to_medoids_filepath,
4848
const char* labels_map_filepath, const char* unv_label_filepath,
49+
const char* seller_filepath,
4950
bool load_bitmask_label = false);
5051
#endif
5152

@@ -56,27 +57,31 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
5657

5758
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
5859
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
60+
uint32_t maxLperSeller = 0,
5961
const bool use_reorder_data = false,
6062
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
6163
QueryStats *stats = nullptr);
6264

6365
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
6466
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
6567
const bool use_filter, const LabelT &filter_label,
68+
uint32_t maxLperSeller = 0,
6669
const bool use_reorder_data = false,
6770
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
6871
QueryStats *stats = nullptr);
6972

7073
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
7174
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
72-
const uint32_t io_limit, const bool use_reorder_data = false,
75+
const uint32_t io_limit,
76+
uint32_t maxLperSeller = 0, const bool use_reorder_data = false,
7377
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
7478
QueryStats *stats = nullptr);
7579

7680
DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search,
7781
uint64_t *res_ids, float *res_dists, const uint64_t beam_width,
7882
const bool use_filter, const LabelT &filter_label,
79-
const uint32_t io_limit, const bool use_reorder_data = false,
83+
const uint32_t io_limit, uint32_t maxLperSeller = 0,
84+
const bool use_reorder_data = false,
8085
std::function<float(const std::uint8_t*, size_t)> rerank_fn = nullptr,
8186
QueryStats *stats = nullptr);
8287

@@ -121,6 +126,9 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
121126
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char>& infile);
122127
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
123128
uint32_t &num_total_labels);
129+
130+
DISKANN_DLLEXPORT void parse_seller_file(const std::string& label_file, size_t& num_pts_labels);
131+
124132
void reset_stream_for_reading(std::basic_istream<char> &infile);
125133

126134
// sector # on disk where node_id is present with in the graph part
@@ -235,6 +243,10 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
235243
tsl::robin_map<uint32_t, std::vector<uint32_t>> _real_to_dummy_map;
236244
std::unordered_map<std::string, LabelT> _label_map;
237245

246+
bool _diverse_index = false;
247+
std::vector<uint32_t> _location_to_seller;
248+
uint32_t _num_unique_sellers = 0;
249+
238250
TableStats _table_stats;
239251

240252
#ifdef EXEC_ENV_OLS

include/scratch.h

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -165,11 +165,12 @@ template <typename T> class SSDQueryScratch : public AbstractScratch<T>
165165

166166
tsl::robin_set<size_t> visited;
167167
NeighborPriorityQueue retset;
168+
NeighborPriorityQueueExtendColor _best_diverse_nodes;
168169
std::vector<Neighbor> full_retset;
169170
// bitmask buffer in searching time
170171
std::vector<std::uint64_t> _query_label_bitmask;
171172

172-
SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t bitmask_size = 0);
173+
SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector<uint32_t>& location_to_sellers, size_t bitmask_size = 0);
173174
~SSDQueryScratch();
174175

175176
void reset();
@@ -186,7 +187,7 @@ template <typename T> class SSDThreadData
186187
SSDQueryScratch<T> scratch;
187188
IOContext ctx;
188189

189-
SSDThreadData(size_t aligned_dim, size_t visited_reserve);
190+
SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector<uint32_t>& location_to_sellers);
190191
void clear();
191192
};
192193

src/disk_utils.cpp

Lines changed: 37 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -630,7 +630,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
630630
std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq,
631631
uint32_t num_threads, bool use_filters, const std::string &label_file,
632632
const std::string &labels_to_medoids_file, const std::string &universal_label,
633-
const uint32_t Lf, uint32_t universal_label_num = 0)
633+
const uint32_t Lf, uint32_t universal_label_num = 0,
634+
const char* seller_file_path = nullptr,
635+
uint32_t num_diverse_build = 1)
634636
{
635637
size_t base_num, base_dim;
636638
diskann::get_bin_metadata(base_file, base_num, base_dim);
@@ -643,10 +645,18 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr
643645
diskann::cout << "Full index fits in RAM budget, should consume at most "
644646
<< full_index_ram / (1024 * 1024 * 1024) << "GiBs, so building in one shot" << std::endl;
645647

648+
bool is_diverse_index = false;
649+
if (seller_file_path != nullptr && !std::string(seller_file_path).empty())
650+
{
651+
is_diverse_index = true;
652+
}
646653
diskann::IndexWriteParameters paras = diskann::IndexWriteParametersBuilder(L, R)
647654
.with_filter_list_size(Lf)
648655
.with_saturate_graph(!use_filters)
649656
.with_num_threads(num_threads)
657+
.with_diverse_index(is_diverse_index)
658+
.with_seller_file(seller_file_path)
659+
.with_num_diverse_build(num_diverse_build)
650660
.build();
651661
using TagT = uint32_t;
652662
diskann::Index<T, TagT, LabelT> _index(compareMetric, base_dim, base_num,
@@ -1106,7 +1116,9 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
11061116
diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters,
11071117
const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold,
11081118
const uint32_t Lf,
1109-
const char* reorderDataFilePath)
1119+
const char* reorderDataFilePath,
1120+
const char* sellerFilePath,
1121+
uint32_t num_diverse_build)
11101122
{
11111123
std::stringstream parser;
11121124
parser << std::string(indexBuildParameters);
@@ -1194,6 +1206,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
11941206
std::string mem_univ_label_file = mem_index_path + "_universal_label.txt";
11951207
std::string disk_univ_label_file = disk_index_path + "_universal_label.txt";
11961208
std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt";
1209+
std::string disk_seller_file = disk_index_path + "_sellers.txt";
11971210
std::string dummy_remap_file = disk_index_path + "_dummy_remap.txt"; // remap will be used if we break-up points of
11981211
// high label-density to create copies
11991212

@@ -1333,7 +1346,8 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
13331346
diskann::build_merged_vamana_index<T, LabelT>(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val,
13341347
indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
13351348
build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use,
1336-
labels_to_medoids_path, universal_label, Lf, universal_label_id);
1349+
labels_to_medoids_path, universal_label, Lf, universal_label_id,
1350+
sellerFilePath, num_diverse_build);
13371351
diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl;
13381352

13391353
timer.reset();
@@ -1377,6 +1391,14 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const
13771391
std::remove(augmented_labels_file.c_str());
13781392
std::remove(labels_file_to_use.c_str());
13791393
}
1394+
1395+
std::string seller_mem_file = std::string(mem_index_path) + "_sellers.txt";
1396+
if (file_exists(seller_mem_file))
1397+
{
1398+
copy_file(seller_mem_file, disk_seller_file);
1399+
std::remove(seller_mem_file.c_str());
1400+
}
1401+
13801402
if (created_temp_file_for_processed_data)
13811403
std::remove(prepped_base.c_str());
13821404
std::remove(mem_index_path.c_str());
@@ -1448,23 +1470,26 @@ template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint32_t>(const char *da
14481470
const std::string &label_file,
14491471
const std::string &universal_label,
14501472
const uint32_t filter_threshold, const uint32_t Lf,
1451-
const char* reorderDataFilePath);
1473+
const char* reorderDataFilePath, const char* sellerFilePath,
1474+
uint32_t num_diverse_build);
14521475
template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint32_t>(const char *dataFilePath, const char *indexFilePath,
14531476
const char *indexBuildParameters,
14541477
diskann::Metric compareMetric, bool use_opq,
14551478
const std::string &codebook_prefix, bool use_filters,
14561479
const std::string &label_file,
14571480
const std::string &universal_label,
14581481
const uint32_t filter_threshold, const uint32_t Lf,
1459-
const char* reorderDataFilePath);
1482+
const char* reorderDataFilePath, const char* sellerFilePath,
1483+
uint32_t num_diverse_build);
14601484
template DISKANN_DLLEXPORT int build_disk_index<float, uint32_t>(const char *dataFilePath, const char *indexFilePath,
14611485
const char *indexBuildParameters,
14621486
diskann::Metric compareMetric, bool use_opq,
14631487
const std::string &codebook_prefix, bool use_filters,
14641488
const std::string &label_file,
14651489
const std::string &universal_label,
14661490
const uint32_t filter_threshold, const uint32_t Lf,
1467-
const char* reorderDataFilePath);
1491+
const char* reorderDataFilePath, const char* sellerFilePath,
1492+
uint32_t num_diverse_build);
14681493
// LabelT = uint16
14691494
template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint16_t>(const char *dataFilePath, const char *indexFilePath,
14701495
const char *indexBuildParameters,
@@ -1473,23 +1498,26 @@ template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint16_t>(const char *da
14731498
const std::string &label_file,
14741499
const std::string &universal_label,
14751500
const uint32_t filter_threshold, const uint32_t Lf,
1476-
const char* reorderDataFilePath);
1501+
const char* reorderDataFilePath, const char* sellerFilePath,
1502+
uint32_t num_diverse_build);
14771503
template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint16_t>(const char *dataFilePath, const char *indexFilePath,
14781504
const char *indexBuildParameters,
14791505
diskann::Metric compareMetric, bool use_opq,
14801506
const std::string &codebook_prefix, bool use_filters,
14811507
const std::string &label_file,
14821508
const std::string &universal_label,
14831509
const uint32_t filter_threshold, const uint32_t Lf,
1484-
const char* reorderDataFilePath);
1510+
const char* reorderDataFilePath, const char* sellerFilePath,
1511+
uint32_t num_diverse_build);
14851512
template DISKANN_DLLEXPORT int build_disk_index<float, uint16_t>(const char *dataFilePath, const char *indexFilePath,
14861513
const char *indexBuildParameters,
14871514
diskann::Metric compareMetric, bool use_opq,
14881515
const std::string &codebook_prefix, bool use_filters,
14891516
const std::string &label_file,
14901517
const std::string &universal_label,
14911518
const uint32_t filter_threshold, const uint32_t Lf,
1492-
const char* reorderDataFilePath);
1519+
const char* reorderDataFilePath, const char* sellerFilePath,
1520+
uint32_t num_diverse_build);
14931521

14941522
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t, uint32_t>(
14951523
std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate,

0 commit comments

Comments
 (0)