Skip to content

Commit a4bd18d

Browse files
authored
Adding num_searchers to KnnIndexTester to simulate multiple callers (#130492)
1 parent f91124a commit a4bd18d

File tree

2 files changed

+87
-10
lines changed

2 files changed

+87
-10
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ record CmdLineArgs(
4242
int hnswM,
4343
int hnswEfConstruction,
4444
int searchThreads,
45+
int numSearchers,
4546
int indexThreads,
4647
boolean reindex,
4748
boolean forceMerge,
@@ -64,6 +65,7 @@ record CmdLineArgs(
6465
static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor");
6566
static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m");
6667
static final ParseField HNSW_EF_CONSTRUCTION_FIELD = new ParseField("hnsw_ef_construction");
68+
static final ParseField NUM_SEARCHERS_FIELD = new ParseField("num_searchers");
6769
static final ParseField SEARCH_THREADS_FIELD = new ParseField("search_threads");
6870
static final ParseField INDEX_THREADS_FIELD = new ParseField("index_threads");
6971
static final ParseField REINDEX_FIELD = new ParseField("reindex");
@@ -95,6 +97,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
9597
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
9698
PARSER.declareInt(Builder::setHnswEfConstruction, HNSW_EF_CONSTRUCTION_FIELD);
9799
PARSER.declareInt(Builder::setSearchThreads, SEARCH_THREADS_FIELD);
100+
PARSER.declareInt(Builder::setNumSearchers, NUM_SEARCHERS_FIELD);
98101
PARSER.declareInt(Builder::setIndexThreads, INDEX_THREADS_FIELD);
99102
PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD);
100103
PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD);
@@ -125,6 +128,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
125128
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
126129
builder.field(HNSW_EF_CONSTRUCTION_FIELD.getPreferredName(), hnswEfConstruction);
127130
builder.field(SEARCH_THREADS_FIELD.getPreferredName(), searchThreads);
131+
builder.field(NUM_SEARCHERS_FIELD.getPreferredName(), numSearchers);
128132
builder.field(INDEX_THREADS_FIELD.getPreferredName(), indexThreads);
129133
builder.field(REINDEX_FIELD.getPreferredName(), reindex);
130134
builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge);
@@ -154,6 +158,7 @@ static class Builder {
154158
private int hnswM = 16;
155159
private int hnswEfConstruction = 200;
156160
private int searchThreads = 1;
161+
private int numSearchers = 1;
157162
private int indexThreads = 1;
158163
private boolean reindex = false;
159164
private boolean forceMerge = false;
@@ -228,6 +233,11 @@ public Builder setSearchThreads(int searchThreads) {
228233
return this;
229234
}
230235

236+
public Builder setNumSearchers(int numSearchers) {
237+
this.numSearchers = numSearchers;
238+
return this;
239+
}
240+
231241
public Builder setIndexThreads(int indexThreads) {
232242
this.indexThreads = indexThreads;
233243
return this;
@@ -291,6 +301,7 @@ public CmdLineArgs build() {
291301
hnswM,
292302
hnswEfConstruction,
293303
searchThreads,
304+
numSearchers,
294305
indexThreads,
295306
reindex,
296307
forceMerge,

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555

5656
import java.io.IOException;
5757
import java.io.OutputStream;
58+
import java.io.UncheckedIOException;
5859
import java.nio.ByteBuffer;
5960
import java.nio.ByteOrder;
6061
import java.nio.IntBuffer;
@@ -71,7 +72,9 @@
7172
import java.util.concurrent.ExecutorService;
7273
import java.util.concurrent.Executors;
7374
import java.util.concurrent.ForkJoinPool;
75+
import java.util.concurrent.Future;
7476
import java.util.concurrent.TimeUnit;
77+
import java.util.function.IntConsumer;
7578

7679
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
7780
import static org.elasticsearch.test.knn.KnnIndexTester.logger;
@@ -96,6 +99,7 @@ class KnnSearcher {
9699
private final VectorEncoding vectorEncoding;
97100
private final float overSamplingFactor;
98101
private final int searchThreads;
102+
private final int numSearchers;
99103

100104
KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
101105
this.docPath = cmdLineArgs.docVectors();
@@ -115,6 +119,7 @@ class KnnSearcher {
115119
this.nProbe = nProbe;
116120
this.indexType = cmdLineArgs.indexType();
117121
this.searchThreads = cmdLineArgs.searchThreads();
122+
this.numSearchers = cmdLineArgs.numSearchers();
118123
}
119124

120125
void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
@@ -124,7 +129,10 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
124129
int offsetByteSize = 0;
125130
try (
126131
FileChannel input = FileChannel.open(queryPath);
127-
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
132+
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"));
133+
ExecutorService numSearchersExecutor = numSearchers > 1
134+
? Executors.newFixedThreadPool(numSearchers, r -> new Thread(r, "KnnSearcher-Caller"))
135+
: null
128136
) {
129137
long queryPathSizeInBytes = input.size();
130138
logger.info(
@@ -163,29 +171,87 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
163171
}
164172
}
165173
targetReader.reset();
174+
final IntConsumer[] queryConsumers = new IntConsumer[numSearchers];
175+
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
176+
byte[][] queries = new byte[numQueryVectors][dim];
177+
for (int i = 0; i < numQueryVectors; i++) {
178+
targetReader.next(queries[i]);
179+
}
180+
for (int s = 0; s < numSearchers; s++) {
181+
queryConsumers[s] = i -> {
182+
try {
183+
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
184+
} catch (IOException e) {
185+
throw new UncheckedIOException(e);
186+
}
187+
};
188+
}
189+
} else {
190+
float[][] queries = new float[numQueryVectors][dim];
191+
for (int i = 0; i < numQueryVectors; i++) {
192+
targetReader.next(queries[i]);
193+
}
194+
for (int s = 0; s < numSearchers; s++) {
195+
queryConsumers[s] = i -> {
196+
try {
197+
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
198+
} catch (IOException e) {
199+
throw new UncheckedIOException(e);
200+
}
201+
};
202+
}
203+
}
204+
int[][] querySplits = new int[numSearchers][];
205+
int queriesPerSearcher = numQueryVectors / numSearchers;
206+
for (int s = 0; s < numSearchers; s++) {
207+
int start = s * queriesPerSearcher;
208+
int end = (s == numSearchers - 1) ? numQueryVectors : (s + 1) * queriesPerSearcher;
209+
querySplits[s] = new int[end - start];
210+
for (int i = start; i < end; i++) {
211+
querySplits[s][i - start] = i;
212+
}
213+
}
214+
targetReader.reset();
166215
startNS = System.nanoTime();
167216
KnnIndexTester.ThreadDetails startThreadDetails = new KnnIndexTester.ThreadDetails();
168-
for (int i = 0; i < numQueryVectors; i++) {
169-
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
170-
targetReader.next(targetBytes);
171-
results[i] = doVectorQuery(targetBytes, searcher, earlyTermination);
172-
} else {
173-
targetReader.next(target);
174-
results[i] = doVectorQuery(target, searcher, earlyTermination);
217+
if (numSearchersExecutor != null) {
218+
// use multiple searchers
219+
var futures = new ArrayList<Future<Void>>();
220+
for (int s = 0; s < numSearchers; s++) {
221+
int[] split = querySplits[s];
222+
IntConsumer queryConsumer = queryConsumers[s];
223+
futures.add(numSearchersExecutor.submit(() -> {
224+
for (int j : split) {
225+
queryConsumer.accept(j);
226+
}
227+
return null;
228+
}));
229+
}
230+
for (Future<Void> future : futures) {
231+
try {
232+
future.get();
233+
} catch (Exception e) {
234+
throw new RuntimeException("Error executing searcher thread", e);
235+
}
236+
}
237+
} else {
238+
// use a single searcher
239+
for (int i = 0; i < numQueryVectors; i++) {
240+
queryConsumers[0].accept(i);
175241
}
176242
}
177243
KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails();
178244
elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS);
179245
long startCPUTimeNS = 0;
180246
long endCPUTimeNS = 0;
181247
for (int i = 0; i < startThreadDetails.threadInfos.length; i++) {
182-
if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher-Thread")) {
248+
if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
183249
startCPUTimeNS += startThreadDetails.cpuTimesNS[i];
184250
}
185251
}
186252

187253
for (int i = 0; i < endThreadDetails.threadInfos.length; i++) {
188-
if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher-Thread")) {
254+
if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
189255
endCPUTimeNS += endThreadDetails.cpuTimesNS[i];
190256
}
191257
}

0 commit comments

Comments
 (0)