Skip to content

Commit 34efc98

Browse files
author
Ananya Sutradhar
committed
prefetching inside get_distance
1 parent 2d679a0 commit 34efc98

File tree

5 files changed

+147
-21
lines changed

5 files changed

+147
-21
lines changed

scripts/analyze_dataset_span_simple.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -190,23 +190,6 @@ def main():
190190
print(f"{'Mean normalization':<25} {suggestions['mean_norm']:<15.6f} {'0.0':<15} {'Simple scaling only'}")
191191
print(f"{'Median normalization':<25} {suggestions['median_norm']:<15.6f} {'0.0':<15} {'Robust to outliers'}")
192192

193-
# Practical recommendations
194-
print("\n" + "="*60)
195-
print("PRACTICAL RECOMMENDATIONS")
196-
print("="*60)
197-
198-
print("For DiskANN filter + distance combination:")
199-
print(f"1. RECOMMENDED: Min-Max normalization")
200-
print(f" - Apply: normalized_dist = (distance - {stats['min']:.6f}) * {suggestions['span_norm']:.6f}")
201-
print(f" - Result: distances in [0, 1] range, same as filter similarities")
202-
203-
print(f"\n2. ALTERNATIVE: Robust normalization (if outliers present)")
204-
print(f" - Apply: normalized_dist = (distance - {stats['q25']:.6f}) * {suggestions['robust_norm']:.6f}")
205-
print(f" - Result: most distances in [0, 1], outliers may exceed 1")
206-
207-
print(f"\n3. SIMPLE: Mean normalization (if centering not needed)")
208-
print(f" - Apply: normalized_dist = distance * {suggestions['mean_norm']:.6f}")
209-
print(f" - Result: average distance becomes 1.0")
210193

211194
# Always save simple normalization factors to text file
212195
norm_config_file = args.norm_factors if args.norm_factors else 'normalization_factors.txt'

src/distance.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,31 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) c
120120
{
121121
#ifdef _WINDOWS
122122
#ifdef USE_AVX2
123+
// Prefetch the start of both vectors
124+
_mm_prefetch((const char*)a, _MM_HINT_T0);
125+
_mm_prefetch((const char*)b, _MM_HINT_T0);
126+
123127
__m256 r = _mm256_setzero_ps();
124128
char *pX = (char *)a, *pY = (char *)b;
129+
const char *original_pX = pX;
130+
const char *original_pY = pY;
131+
uint32_t prefetch_offset = 64; // Prefetch 64 bytes ahead
132+
125133
while (size >= 32)
126134
{
135+
// Prefetch ahead for better cache performance
136+
if (size > prefetch_offset)
137+
{
138+
_mm_prefetch(original_pX + prefetch_offset, _MM_HINT_T0);
139+
_mm_prefetch(original_pY + prefetch_offset, _MM_HINT_T0);
140+
}
141+
127142
__m256i r1 = _mm256_subs_epi8(_mm256_loadu_si256((__m256i *)pX), _mm256_loadu_si256((__m256i *)pY));
128143
r = _mm256_add_ps(r, _mm256_mul_epi8(r1, r1));
129144
pX += 32;
130145
pY += 32;
131146
size -= 32;
147+
prefetch_offset += 32;
132148
}
133149
while (size > 0)
134150
{
@@ -141,19 +157,39 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) c
141157
r = _mm256_hadd_ps(_mm256_hadd_ps(r, r), r);
142158
return r.m256_f32[0] + r.m256_f32[4];
143159
#else
160+
// Prefetch the start of both vectors for non-AVX2 fallback
161+
_mm_prefetch((const char*)a, _MM_HINT_T0);
162+
_mm_prefetch((const char*)b, _MM_HINT_T0);
163+
144164
int32_t result = 0;
145165
#pragma omp simd reduction(+ : result) aligned(a, b : 8)
146166
for (int32_t i = 0; i < (int32_t)size; i++)
147167
{
168+
// Prefetch ahead every 64 bytes (64 int8_t values)
169+
if (i % 64 == 0 && i + 64 < (int32_t)size)
170+
{
171+
_mm_prefetch((const char*)(a + i + 64), _MM_HINT_T0);
172+
_mm_prefetch((const char*)(b + i + 64), _MM_HINT_T0);
173+
}
148174
result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i]));
149175
}
150176
return (float)result;
151177
#endif
152178
#else
179+
// Prefetch the start of both vectors for Linux version
180+
_mm_prefetch((const char*)a, _MM_HINT_T0);
181+
_mm_prefetch((const char*)b, _MM_HINT_T0);
182+
153183
int32_t result = 0;
154184
#pragma omp simd reduction(+ : result) aligned(a, b : 8)
155185
for (int32_t i = 0; i < (int32_t)size; i++)
156186
{
187+
// Prefetch ahead every 64 bytes (64 int8_t values)
188+
if (i % 64 == 0 && i + 64 < (int32_t)size)
189+
{
190+
_mm_prefetch((const char*)(a + i + 64), _MM_HINT_T0);
191+
_mm_prefetch((const char*)(b + i + 64), _MM_HINT_T0);
192+
}
157193
result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i]));
158194
}
159195
return (float)result;
@@ -162,12 +198,22 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) c
162198

163199
float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size) const
164200
{
201+
// Prefetch the start of both vectors
202+
_mm_prefetch((const char*)a, _MM_HINT_T0);
203+
_mm_prefetch((const char*)b, _MM_HINT_T0);
204+
165205
uint32_t result = 0;
166206
#ifndef _WINDOWS
167207
#pragma omp simd reduction(+ : result) aligned(a, b : 8)
168208
#endif
169209
for (int32_t i = 0; i < (int32_t)size; i++)
170210
{
211+
// Prefetch ahead every 64 bytes (64 uint8_t values)
212+
if (i % 64 == 0 && i + 64 < (int32_t)size)
213+
{
214+
_mm_prefetch((const char*)(a + i + 64), _MM_HINT_T0);
215+
_mm_prefetch((const char*)(b + i + 64), _MM_HINT_T0);
216+
}
171217
result += ((int32_t)((int16_t)a[i] - (int16_t)b[i])) * ((int32_t)((int16_t)a[i] - (int16_t)b[i]));
172218
}
173219
return (float)result;
@@ -209,11 +255,21 @@ float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) co
209255
// horizontal add sum
210256
result = _mm256_reduce_add_ps(sum);
211257
#else
258+
// Prefetch the start of both vectors for non-AVX2 fallback
259+
_mm_prefetch((const char*)a, _MM_HINT_T0);
260+
_mm_prefetch((const char*)b, _MM_HINT_T0);
261+
212262
#ifndef _WINDOWS
213263
#pragma omp simd reduction(+ : result) aligned(a, b : 32)
214264
#endif
215265
for (int32_t i = 0; i < (int32_t)size; i++)
216266
{
267+
// Prefetch ahead every 16 floats (64 bytes)
268+
if (i % 16 == 0 && i + 16 < (int32_t)size)
269+
{
270+
_mm_prefetch((const char*)(a + i + 16), _MM_HINT_T0);
271+
_mm_prefetch((const char*)(b + i + 16), _MM_HINT_T0);
272+
}
217273
result += (a[i] - b[i]) * (a[i] - b[i]);
218274
}
219275
#endif
@@ -271,18 +327,34 @@ float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t leng
271327

272328
float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t length) const
273329
{
330+
// Prefetch the start of both vectors
331+
_mm_prefetch((const char*)a, _MM_HINT_T0);
332+
_mm_prefetch((const char*)b, _MM_HINT_T0);
333+
274334
__m128 diff, v1, v2;
275335
__m128 sum = _mm_set1_ps(0);
336+
337+
const float *original_a = a;
338+
const float *original_b = b;
339+
uint32_t prefetch_offset = 64; // Prefetch 64 bytes ahead (16 floats)
276340

277341
while (length >= 4)
278342
{
343+
// Prefetch ahead for better cache performance
344+
if (length > prefetch_offset)
345+
{
346+
_mm_prefetch((const char*)(original_a + prefetch_offset), _MM_HINT_T0);
347+
_mm_prefetch((const char*)(original_b + prefetch_offset), _MM_HINT_T0);
348+
}
349+
279350
v1 = _mm_loadu_ps(a);
280351
a += 4;
281352
v2 = _mm_loadu_ps(b);
282353
b += 4;
283354
diff = _mm_sub_ps(v1, v2);
284355
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
285356
length -= 4;
357+
prefetch_offset += 4;
286358
}
287359

288360
return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] + sum.m128_f32[3];

src/in_mem_data_store.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
#include "utils.h"
99

10+
#ifdef _WIN32
11+
#include <intrin.h>
12+
#else
13+
#include <x86intrin.h>
14+
#endif
15+
1016
namespace diskann
1117
{
1218

@@ -197,17 +203,38 @@ void InMemDataStore<data_t>::preprocess_query(const data_t *query, AbstractScrat
197203

198204
// Returns the distance between `query` and the stored vector at index `loc`.
//
// The stored vector is addressed as `_data + _aligned_dim * loc`; the actual comparison is
// delegated to `_distance_fn->compare` over `_aligned_dim` elements (cast to uint32_t).
//
// NOTE(review): this span is a rendered diff. Lines made only of digits are gutter line
// numbers, a trailing '-' marks the removed (old) line and a trailing '+' marks added lines.
// The old body was the single `return` directly below; the new body adds two prefetch
// hints before the same comparison.
template <typename data_t> float InMemDataStore<data_t>::get_distance(const data_t *query, const location_t loc) const
199205
{
200-
return _distance_fn->compare(query, _data + _aligned_dim * loc, (uint32_t)_aligned_dim);
206+
// Prefetch query vector for better cache performance
207+
// NOTE(review): only the address of the first element is passed, so presumably only the
// leading cache line of the query is requested — confirm against the intrinsic's docs.
_mm_prefetch((const char*)query, _MM_HINT_T0);
208+
209+
// Prefetch the current vector data
210+
const data_t *vector_data = _data + _aligned_dim * loc;
211+
// NOTE(review): this hint is issued immediately before compare() reads the same address,
// so there is no intervening work to overlap with — benefit should be measured.
_mm_prefetch((const char*)vector_data, _MM_HINT_T0);
212+
213+
return _distance_fn->compare(query, vector_data, (uint32_t)_aligned_dim);
201214
}
202215

203216
// Batch distance computation: for each i in [0, location_count), writes into `distances[i]`
// the result of `_distance_fn->compare` between `query` and the stored vector at
// `_data + locations[i] * _aligned_dim`, over `_aligned_dim` elements.
//
// Parameters:
//   query          - query vector passed unchanged to every comparison.
//   locations      - array of `location_count` stored-vector indices.
//   location_count - number of entries read from `locations` and written to `distances`.
//   distances      - output array; element i receives the distance for locations[i].
//   scratch_space  - not referenced anywhere in the visible body.
//
// NOTE(review): this span is a rendered diff. Lines made only of digits are gutter line
// numbers, a trailing '-' marks the removed (old) line and a trailing '+' marks added lines.
template <typename data_t>
204217
void InMemDataStore<data_t>::get_distance(const data_t *query, const location_t *locations,
205218
const uint32_t location_count, float *distances,
206219
AbstractScratch<data_t> *scratch_space) const
207220
{
221+
// Prefetch query vector once for the entire batch
222+
_mm_prefetch((const char*)query, _MM_HINT_T0);
223+
208224
for (location_t i = 0; i < location_count; i++)
209225
{
210-
distances[i] = _distance_fn->compare(query, _data + locations[i] * _aligned_dim, (uint32_t)this->_aligned_dim);
226+
// Prefetch current vector data
227+
const data_t *current_vector = _data + locations[i] * _aligned_dim;
228+
// NOTE(review): for i > 0 this address was already hinted as `next_vector` in the
// previous iteration, so the same address is hinted twice.
_mm_prefetch((const char*)current_vector, _MM_HINT_T0);
229+
230+
// Prefetch next vector data if available
231+
// The bound check keeps locations[i + 1] from being read past the last entry.
if (i + 1 < location_count)
232+
{
233+
const data_t *next_vector = _data + locations[i + 1] * _aligned_dim;
234+
_mm_prefetch((const char*)next_vector, _MM_HINT_T0);
235+
}
236+
237+
distances[i] = _distance_fn->compare(query, current_vector, (uint32_t)this->_aligned_dim);
211238
}
212239
}
213240

@@ -222,10 +249,24 @@ template <typename data_t>
222249
// Batch distance computation over a vector of ids: for each i in [0, ids.size()), writes into
// `distances[i]` the result of `_distance_fn->compare` between `preprocessed_query` and the
// stored vector at `_data + ids[i] * _aligned_dim`, over `_aligned_dim` elements.
//
// Parameters:
//   preprocessed_query - query vector passed unchanged to every comparison.
//   ids                - indices of the stored vectors to compare against.
//   distances          - output vector, indexed with operator[]; nothing in the visible body
//                        resizes it, so it must already hold at least ids.size() elements.
//   scratch_space      - not referenced anywhere in the visible body.
//
// NOTE(review): this span is a rendered diff. Lines made only of digits are gutter line
// numbers, a trailing '-' marks the removed (old) line and a trailing '+' marks added lines.
void InMemDataStore<data_t>::get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
223250
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const
224251
{
252+
// Prefetch query vector once for the entire batch
253+
_mm_prefetch((const char*)preprocessed_query, _MM_HINT_T0);
254+
225255
// NOTE(review): `i` is a signed int compared against ids.size() here, while the added
// bound check further down casts ids.size() to int — the two comparisons are inconsistent.
for (int i = 0; i < ids.size(); i++)
226256
{
257+
// Prefetch current vector data
258+
const data_t *current_vector = _data + ids[i] * _aligned_dim;
259+
// NOTE(review): for i > 0 this address was already hinted as `next_vector` in the
// previous iteration, so the same address is hinted twice.
_mm_prefetch((const char*)current_vector, _MM_HINT_T0);
260+
261+
// Prefetch next vector data if available
262+
// The bound check keeps ids[i + 1] from being read past the last element.
if (i + 1 < (int)ids.size())
263+
{
264+
const data_t *next_vector = _data + ids[i + 1] * _aligned_dim;
265+
_mm_prefetch((const char*)next_vector, _MM_HINT_T0);
266+
}
267+
227268
distances[i] =
228-
_distance_fn->compare(preprocessed_query, _data + ids[i] * _aligned_dim, (uint32_t)this->_aligned_dim);
269+
_distance_fn->compare(preprocessed_query, current_vector, (uint32_t)this->_aligned_dim);
229270
}
230271
}
231272

src/index.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,8 @@ inline void Index<T, TagT, LabelT>::calculate_jaccard_similarity_batch(
10741074

10751075
// Check if any label in this clause matches the point
10761076
for (const auto& label : clause) {
1077-
if (_labels_to_points_set[label].count(point_id)) {
1077+
auto it = _labels_to_points_set.find(label);
1078+
if (it != _labels_to_points_set.end() && it->second.count(point_id)) {
10781079
matching_clauses++;
10791080
break; // Found match, move to next clause
10801081
}

src/pq_data_store.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
#include "utils.h"
77
#include "distance.h"
88

9+
#ifdef _WIN32
10+
#include <intrin.h>
11+
#else
12+
#include <x86intrin.h>
13+
#endif
14+
915
namespace diskann
1016
{
1117

@@ -180,6 +186,17 @@ void PQDataStore<data_t>::get_distance(const data_t *preprocessed_query, const l
180186
{
181187
throw diskann::ANNException("PQScratch not set in scratch space.", -1);
182188
}
189+
190+
// Prefetch preprocessed query data
191+
_mm_prefetch((const char*)preprocessed_query, _MM_HINT_T0);
192+
193+
// Prefetch some of the quantized data for the locations we'll process
194+
for (uint32_t i = 0; i < location_count && i < 4; i++)
195+
{
196+
const uint8_t *quantized_loc = _quantized_data + locations[i] * this->_num_chunks;
197+
_mm_prefetch((const char*)quantized_loc, _MM_HINT_T0);
198+
}
199+
183200
diskann::aggregate_coords(locations, location_count, _quantized_data, this->_num_chunks,
184201
pq_scratch->aligned_pq_coord_scratch);
185202
_pq_distance_fn->preprocessed_distance(*pq_scratch, location_count, distances);
@@ -198,6 +215,18 @@ void PQDataStore<data_t>::get_distance(const data_t *preprocessed_query, const s
198215
{
199216
throw diskann::ANNException("PQScratch not set in scratch space.", -1);
200217
}
218+
219+
// Prefetch preprocessed query data
220+
_mm_prefetch((const char*)preprocessed_query, _MM_HINT_T0);
221+
222+
// Prefetch some of the quantized data for the locations we'll process
223+
size_t prefetch_count = ids.size() < 4 ? ids.size() : 4;
224+
for (size_t i = 0; i < prefetch_count; i++)
225+
{
226+
const uint8_t *quantized_loc = _quantized_data + ids[i] * this->_num_chunks;
227+
_mm_prefetch((const char*)quantized_loc, _MM_HINT_T0);
228+
}
229+
201230
diskann::aggregate_coords(ids, _quantized_data, this->_num_chunks, pq_scratch->aligned_pq_coord_scratch);
202231
_pq_distance_fn->preprocessed_distance(*pq_scratch, (location_t)ids.size(), distances);
203232
}

0 commit comments

Comments
 (0)