Skip to content

Commit a6cd177

Browse files
author
Ananya Sutradhar
committed
moved temporary vectors to scratch space
1 parent 9318e9a commit a6cd177

File tree

3 files changed

+66
-5
lines changed

3 files changed

+66
-5
lines changed

include/scratch.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,36 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
122122
return _tmp_intersection;
123123
}
124124

125+
// Scratch space accessors for Jaccard similarity calculations
126+
inline std::vector<float> &jaccard_similarities_scratch()
127+
{
128+
return _jaccard_similarities_scratch;
129+
}
130+
131+
inline std::vector<float> &init_penalties_scratch()
132+
{
133+
return _init_penalties_scratch;
134+
}
135+
136+
inline std::vector<uint32_t> &valid_init_candidates_scratch()
137+
{
138+
return _valid_init_candidates_scratch;
139+
}
140+
141+
inline std::vector<uint32_t> &filtered_ids_scratch()
142+
{
143+
return _filtered_ids_scratch;
144+
}
145+
146+
// Reserve capacity in the Jaccard scratch vectors for batch operations (calls reserve() only; vector sizes are unchanged)
147+
inline void resize_jaccard_scratch_for_batch(size_t expected_batch_size)
148+
{
149+
_jaccard_similarities_scratch.reserve(expected_batch_size);
150+
_init_penalties_scratch.reserve(expected_batch_size);
151+
_valid_init_candidates_scratch.reserve(expected_batch_size);
152+
_filtered_ids_scratch.reserve(expected_batch_size);
153+
}
154+
125155

126156
private:
127157
uint32_t _L;
@@ -169,6 +199,12 @@ template <typename T> class InMemQueryScratch : public AbstractScratch<T>
169199
std::vector<uint32_t> _closest_clusters;
170200
std::vector<float> _cluster_distances;
171201
float *_aligned_query_float;
202+
203+
// Scratch space for Jaccard similarity batch calculations
204+
std::vector<float> _jaccard_similarities_scratch;
205+
std::vector<float> _init_penalties_scratch;
206+
std::vector<uint32_t> _valid_init_candidates_scratch;
207+
std::vector<uint32_t> _filtered_ids_scratch;
172208
};
173209

174210
//

src/index.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,10 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
11301130
uint32_t hops = 0;
11311131
uint32_t cmps = 0;
11321132

1133+
// Reserve scratch vector capacity up front so batch operations avoid reallocations (sizes are not changed)
1134+
size_t estimated_batch_size = std::max(init_ids.size(), (size_t)(Lsize * 2));
1135+
scratch->resize_jaccard_scratch_for_batch(estimated_batch_size);
1136+
11331137
float *pq_dists = nullptr;
11341138

11351139
if (print_qstats)
@@ -1166,7 +1170,8 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
11661170
uint32_t candidates_filtered = 0;
11671171

11681172
// STEP 1: Collect all valid initial candidates first
1169-
std::vector<uint32_t> valid_init_candidates;
1173+
std::vector<uint32_t> &valid_init_candidates = scratch->valid_init_candidates_scratch();
1174+
valid_init_candidates.clear();
11701175
valid_init_candidates.reserve(init_ids.size());
11711176

11721177
for (auto id : init_ids)
@@ -1202,11 +1207,13 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
12021207
}
12031208

12041209
// STEP 2: BATCH PROCESS all initial candidates' filter scores at once
1205-
std::vector<float> init_penalties;
1210+
std::vector<float> &init_penalties = scratch->init_penalties_scratch();
1211+
init_penalties.clear();
12061212
if (use_filter && !valid_init_candidates.empty())
12071213
{
12081214
auto jaccard_start = std::chrono::high_resolution_clock::now();
1209-
std::vector<float> jaccard_similarities;
1215+
std::vector<float> &jaccard_similarities = scratch->jaccard_similarities_scratch();
1216+
jaccard_similarities.clear();
12101217
calculate_jaccard_similarity_batch(valid_init_candidates, filter_labels, jaccard_similarities);
12111218
std::chrono::duration<double> jaccard_diff = std::chrono::high_resolution_clock::now() - jaccard_start;
12121219
curr_jaccard_time += jaccard_diff.count();
@@ -1425,14 +1432,16 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
14251432
if (use_filter && !id_scratch.empty())
14261433
{
14271434
auto jaccard_start = std::chrono::high_resolution_clock::now();
1428-
std::vector<float> jaccard_similarities;
1435+
std::vector<float> &jaccard_similarities = scratch->jaccard_similarities_scratch();
1436+
jaccard_similarities.clear();
14291437
calculate_jaccard_similarity_batch(id_scratch, filter_labels, jaccard_similarities);
14301438
std::chrono::duration<double> jaccard_diff = std::chrono::high_resolution_clock::now() - jaccard_start;
14311439
curr_jaccard_time += jaccard_diff.count();
14321440

14331441
// Convert similarities to penalties and apply filtering
14341442
res_vec.reserve(id_scratch.size());
1435-
std::vector<uint32_t> filtered_ids;
1443+
std::vector<uint32_t> &filtered_ids = scratch->filtered_ids_scratch();
1444+
filtered_ids.clear();
14361445
filtered_ids.reserve(id_scratch.size());
14371446

14381447
for (size_t i = 0; i < id_scratch.size(); ++i)

src/scratch.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ InMemQueryScratch<T>::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l,
4242
_inserted_into_pool_bs = new boost::dynamic_bitset<>();
4343
_id_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R * _R));
4444
_dist_scratch.reserve((size_t)std::ceil(1.5 * defaults::GRAPH_SLACK_FACTOR * _R * _R));
45+
46+
// Initialize scratch space for Jaccard similarity calculations
47+
size_t max_neighbor_batch = (size_t)std::ceil(defaults::GRAPH_SLACK_FACTOR * _R); // Single level neighbors
48+
size_t max_init_batch = 1000; // Conservative estimate for initial candidates
49+
size_t max_batch_size = std::max(max_neighbor_batch, max_init_batch);
50+
51+
_jaccard_similarities_scratch.reserve(max_batch_size);
52+
_init_penalties_scratch.reserve(max_batch_size);
53+
_valid_init_candidates_scratch.reserve(max_batch_size);
54+
_filtered_ids_scratch.reserve(max_batch_size);
4555

4656
resize_for_new_L(std::max(search_l, indexing_l));
4757
}
@@ -62,6 +72,12 @@ template <typename T> void InMemQueryScratch<T>::clear()
6272
_expanded_nodes_set.clear();
6373
_expanded_nghrs_vec.clear();
6474
_occlude_list_output.clear();
75+
76+
// Clear Jaccard similarity scratch vectors
77+
_jaccard_similarities_scratch.clear();
78+
_init_penalties_scratch.clear();
79+
_valid_init_candidates_scratch.clear();
80+
_filtered_ids_scratch.clear();
6581
}
6682

6783
template <typename T> void InMemQueryScratch<T>::resize_for_new_L(uint32_t new_l)

0 commit comments

Comments
 (0)