From 58d20c8ad425930782df82ae2ede799749d5c1e5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 2 Jul 2025 16:28:47 -0400 Subject: [PATCH] Adding num_searchers to KnnIndexTester to simulate multiple callers --- .../elasticsearch/test/knn/CmdLineArgs.java | 11 +++ .../elasticsearch/test/knn/KnnSearcher.java | 86 ++++++++++++++++--- 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index 75497c4fcc392..204b70385641e 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -42,6 +42,7 @@ record CmdLineArgs( int hnswM, int hnswEfConstruction, int searchThreads, + int numSearchers, int indexThreads, boolean reindex, boolean forceMerge, @@ -64,6 +65,7 @@ record CmdLineArgs( static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor"); static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m"); static final ParseField HNSW_EF_CONSTRUCTION_FIELD = new ParseField("hnsw_ef_construction"); + static final ParseField NUM_SEARCHERS_FIELD = new ParseField("num_searchers"); static final ParseField SEARCH_THREADS_FIELD = new ParseField("search_threads"); static final ParseField INDEX_THREADS_FIELD = new ParseField("index_threads"); static final ParseField REINDEX_FIELD = new ParseField("reindex"); @@ -95,6 +97,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD); PARSER.declareInt(Builder::setHnswEfConstruction, HNSW_EF_CONSTRUCTION_FIELD); PARSER.declareInt(Builder::setSearchThreads, SEARCH_THREADS_FIELD); + PARSER.declareInt(Builder::setNumSearchers, NUM_SEARCHERS_FIELD); PARSER.declareInt(Builder::setIndexThreads, INDEX_THREADS_FIELD); PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); @@ -125,6 +128,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(HNSW_M_FIELD.getPreferredName(), hnswM); builder.field(HNSW_EF_CONSTRUCTION_FIELD.getPreferredName(), hnswEfConstruction); builder.field(SEARCH_THREADS_FIELD.getPreferredName(), searchThreads); + builder.field(NUM_SEARCHERS_FIELD.getPreferredName(), numSearchers); builder.field(INDEX_THREADS_FIELD.getPreferredName(), indexThreads); builder.field(REINDEX_FIELD.getPreferredName(), reindex); builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); @@ -154,6 +158,7 @@ static class Builder { private int hnswM = 16; private int hnswEfConstruction = 200; private int searchThreads = 1; + private int numSearchers = 1; private int indexThreads = 1; private boolean reindex = false; private boolean forceMerge = false; @@ -228,6 +233,11 @@ public Builder setSearchThreads(int searchThreads) { return this; } + public Builder setNumSearchers(int numSearchers) { + this.numSearchers = numSearchers; + return this; + } + public Builder setIndexThreads(int indexThreads) { this.indexThreads = indexThreads; return this; @@ -291,6 +301,7 @@ public CmdLineArgs build() { hnswM, hnswEfConstruction, searchThreads, + numSearchers, indexThreads, reindex, forceMerge, diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java index 7967797e797f9..7cf8a5846cba3 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java @@ -55,6 +55,7 @@ import java.io.IOException; import java.io.OutputStream; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.IntBuffer; @@ -71,7 +72,9 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.function.IntConsumer; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.elasticsearch.test.knn.KnnIndexTester.logger; @@ -96,6 +99,7 @@ class KnnSearcher { private final VectorEncoding vectorEncoding; private final float overSamplingFactor; private final int searchThreads; + private final int numSearchers; KnnSearcher(Path indexPath, CmdLineArgs cmdLineArgs, int nProbe) { this.docPath = cmdLineArgs.docVectors(); @@ -115,6 +119,7 @@ class KnnSearcher { this.nProbe = nProbe; this.indexType = cmdLineArgs.indexType(); this.searchThreads = cmdLineArgs.searchThreads(); + this.numSearchers = cmdLineArgs.numSearchers(); } void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) throws IOException { @@ -124,7 +129,10 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th int offsetByteSize = 0; try ( FileChannel input = FileChannel.open(queryPath); - ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread")) + ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread")); + ExecutorService numSearchersExecutor = numSearchers > 1 + ? Executors.newFixedThreadPool(numSearchers, r -> new Thread(r, "KnnSearcher-Caller")) + : null ) { long queryPathSizeInBytes = input.size(); logger.info( @@ -163,15 +171,73 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th } } targetReader.reset(); + final IntConsumer[] queryConsumers = new IntConsumer[numSearchers]; + if (vectorEncoding.equals(VectorEncoding.BYTE)) { + byte[][] queries = new byte[numQueryVectors][dim]; + for (int i = 0; i < numQueryVectors; i++) { + targetReader.next(queries[i]); + } + for (int s = 0; s < numSearchers; s++) { + queryConsumers[s] = i -> { + try { + results[i] = doVectorQuery(queries[i], searcher, earlyTermination); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + } else { + float[][] queries = new float[numQueryVectors][dim]; + for (int i = 0; i < numQueryVectors; i++) { + targetReader.next(queries[i]); + } + for (int s = 0; s < numSearchers; s++) { + queryConsumers[s] = i -> { + try { + results[i] = doVectorQuery(queries[i], searcher, earlyTermination); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + } + int[][] querySplits = new int[numSearchers][]; + int queriesPerSearcher = numQueryVectors / numSearchers; + for (int s = 0; s < numSearchers; s++) { + int start = s * queriesPerSearcher; + int end = (s == numSearchers - 1) ? numQueryVectors : (s + 1) * queriesPerSearcher; + querySplits[s] = new int[end - start]; + for (int i = start; i < end; i++) { + querySplits[s][i - start] = i; + } + } + targetReader.reset(); startNS = System.nanoTime(); KnnIndexTester.ThreadDetails startThreadDetails = new KnnIndexTester.ThreadDetails(); - for (int i = 0; i < numQueryVectors; i++) { - if (vectorEncoding.equals(VectorEncoding.BYTE)) { - targetReader.next(targetBytes); - results[i] = doVectorQuery(targetBytes, searcher, earlyTermination); - } else { - targetReader.next(target); - results[i] = doVectorQuery(target, searcher, earlyTermination); + if (numSearchersExecutor != null) { + // use multiple searchers + var futures = new ArrayList>(); + for (int s = 0; s < numSearchers; s++) { + int[] split = querySplits[s]; + IntConsumer queryConsumer = queryConsumers[s]; + futures.add(numSearchersExecutor.submit(() -> { + for (int j : split) { + queryConsumer.accept(j); + } + return null; + })); + } + for (Future future : futures) { + try { + future.get(); + } catch (Exception e) { + throw new RuntimeException("Error executing searcher thread", e); + } + } + } else { + // use a single searcher + for (int i = 0; i < numQueryVectors; i++) { + queryConsumers[0].accept(i); } } KnnIndexTester.ThreadDetails endThreadDetails = new KnnIndexTester.ThreadDetails(); @@ -179,13 +245,13 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th long startCPUTimeNS = 0; long endCPUTimeNS = 0; for (int i = 0; i < startThreadDetails.threadInfos.length; i++) { - if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher-Thread")) { + if (startThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) { startCPUTimeNS += startThreadDetails.cpuTimesNS[i]; } } for (int i = 0; i < endThreadDetails.threadInfos.length; i++) { - if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher-Thread")) { + if (endThreadDetails.threadInfos[i].getThreadName().startsWith("KnnSearcher")) { endCPUTimeNS += endThreadDetails.cpuTimesNS[i]; } }