Skip to content

Adding num_searchers to KnnIndexTester to simulate multiple callers #130492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ record CmdLineArgs(
int hnswM,
int hnswEfConstruction,
int searchThreads,
int numSearchers,
int indexThreads,
boolean reindex,
boolean forceMerge,
Expand All @@ -64,6 +65,7 @@ record CmdLineArgs(
static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor");
static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m");
static final ParseField HNSW_EF_CONSTRUCTION_FIELD = new ParseField("hnsw_ef_construction");
static final ParseField NUM_SEARCHERS_FIELD = new ParseField("num_searchers");
static final ParseField SEARCH_THREADS_FIELD = new ParseField("search_threads");
static final ParseField INDEX_THREADS_FIELD = new ParseField("index_threads");
static final ParseField REINDEX_FIELD = new ParseField("reindex");
Expand Down Expand Up @@ -95,6 +97,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
PARSER.declareInt(Builder::setHnswEfConstruction, HNSW_EF_CONSTRUCTION_FIELD);
PARSER.declareInt(Builder::setSearchThreads, SEARCH_THREADS_FIELD);
PARSER.declareInt(Builder::setNumSearchers, NUM_SEARCHERS_FIELD);
PARSER.declareInt(Builder::setIndexThreads, INDEX_THREADS_FIELD);
PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD);
PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD);
Expand Down Expand Up @@ -125,6 +128,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
builder.field(HNSW_EF_CONSTRUCTION_FIELD.getPreferredName(), hnswEfConstruction);
builder.field(SEARCH_THREADS_FIELD.getPreferredName(), searchThreads);
builder.field(NUM_SEARCHERS_FIELD.getPreferredName(), numSearchers);
builder.field(INDEX_THREADS_FIELD.getPreferredName(), indexThreads);
builder.field(REINDEX_FIELD.getPreferredName(), reindex);
builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge);
Expand Down Expand Up @@ -154,6 +158,7 @@ static class Builder {
private int hnswM = 16;
private int hnswEfConstruction = 200;
private int searchThreads = 1;
private int numSearchers = 1;
private int indexThreads = 1;
private boolean reindex = false;
private boolean forceMerge = false;
Expand Down Expand Up @@ -228,6 +233,11 @@ public Builder setSearchThreads(int searchThreads) {
return this;
}

public Builder setNumSearchers(int numSearchers) {
this.numSearchers = numSearchers;
return this;
}

public Builder setIndexThreads(int indexThreads) {
this.indexThreads = indexThreads;
return this;
Expand Down Expand Up @@ -291,6 +301,7 @@ public CmdLineArgs build() {
hnswM,
hnswEfConstruction,
searchThreads,
numSearchers,
indexThreads,
reindex,
forceMerge,
Expand Down
86 changes: 76 additions & 10 deletions qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
Expand All @@ -71,7 +72,9 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.function.IntConsumer;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.elasticsearch.test.knn.KnnIndexTester.logger;
Expand All @@ -96,6 +99,7 @@ class KnnSearcher {
private final VectorEncoding vectorEncoding;
private final float overSamplingFactor;
private final int searchThreads;
private final int numSearchers;

KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) {
this.docPath = cmdLineArgs.docVectors();
Expand All @@ -115,6 +119,7 @@ class KnnSearcher {
this.nProbe = nProbe;
this.indexType = cmdLineArgs.indexType();
this.searchThreads = cmdLineArgs.searchThreads();
this.numSearchers = cmdLineArgs.numSearchers();
}

void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException {
Expand All @@ -124,7 +129,10 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
int offsetByteSize = 0;
try (
FileChannel input = FileChannel.open(queryPath);
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"));
ExecutorService numSearchersExecutor = numSearchers > 1
? Executors.newFixedThreadPool(numSearchers, r -> new Thread(r, "KnnSearcher-Caller"))
: null
) {
long queryPathSizeInBytes = input.size();
logger.info(
Expand Down Expand Up @@ -163,29 +171,87 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
}
}
targetReader.reset();
final IntConsumer[] queryConsumers = new IntConsumer[numSearchers];
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[][] queries = new byte[numQueryVectors][dim];
for (int i = 0; i < numQueryVectors; i++) {
targetReader.next(queries[i]);
}
for (int s = 0; s < numSearchers; s++) {
queryConsumers[s] = i -> {
try {
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}
} else {
float[][] queries = new float[numQueryVectors][dim];
for (int i = 0; i < numQueryVectors; i++) {
targetReader.next(queries[i]);
}
for (int s = 0; s < numSearchers; s++) {
queryConsumers[s] = i -> {
try {
results[i] = doVectorQuery(queries[i], searcher, earlyTermination);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}
}
int[][] querySplits = new int[numSearchers][];
int queriesPerSearcher = numQueryVectors / numSearchers;
for (int s = 0; s < numSearchers; s++) {
int start = s * queriesPerSearcher;
int end = (s == numSearchers - 1) ? numQueryVectors : (s + 1) * queriesPerSearcher;
querySplits[s] = new int[end - start];
for (int i = start; i < end; i++) {
querySplits[s][i - start] = i;
}
}
targetReader.reset();
startNS = System.nanoTime();
KnnIndexTester.ThreadDetails startThreadDetails = new KnnIndexTester.ThreadDetails();
for (int i = 0; i < numQueryVectors; i++) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
targetReader.next(targetBytes);
results[i] = doVectorQuery(targetBytes, searcher, earlyTermination);
} else {
targetReader.next(target);
results[i] = doVectorQuery(target, searcher, earlyTermination);
if (numSearchersExecutor != null) {
// use multiple searchers
var futures = new ArrayList<Future<Void>>();
for (int s = 0; s < numSearchers; s++) {
int[] split = querySplits[s];
IntConsumer queryConsumer = queryConsumers[s];
futures.add(numSearchersExecutor.submit(() -> {
for (int j : split) {
queryConsumer.accept(j);
}
return null;
}));
}
for (Future<Void> future : futures) {
try {
future.get();
} catch (Exception e) {
throw new RuntimeException("Error executing searcher thread", e);
}
}
} else {
// use a single searcher
for (int i = 0; i < numQueryVectors; i++) {
queryConsumers[0].accept(i);
}
}
KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails();
elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS);
long startCPUTimeNS = 0;
long endCPUTimeNS = 0;
for (int i = 0; i < startThreadDetails.threadInfos.length; i++) {
if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher-Thread")) {
if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
startCPUTimeNS += startThreadDetails.cpuTimesNS[i];
}
}

for (int i = 0; i < endThreadDetails.threadInfos.length; i++) {
if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher-Thread")) {
if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) {
endCPUTimeNS += endThreadDetails.cpuTimesNS[i];
}
}
Expand Down