55
55
56
56
import java .io .IOException ;
57
57
import java .io .OutputStream ;
58
+ import java .io .UncheckedIOException ;
58
59
import java .nio .ByteBuffer ;
59
60
import java .nio .ByteOrder ;
60
61
import java .nio .IntBuffer ;
71
72
import java .util .concurrent .ExecutorService ;
72
73
import java .util .concurrent .Executors ;
73
74
import java .util .concurrent .ForkJoinPool ;
75
+ import java .util .concurrent .Future ;
74
76
import java .util .concurrent .TimeUnit ;
77
+ import java .util .function .IntConsumer ;
75
78
76
79
import static org .apache .lucene .search .DocIdSetIterator .NO_MORE_DOCS ;
77
80
import static org .elasticsearch .test .knn .KnnIndexTester .logger ;
@@ -96,6 +99,7 @@ class KnnSearcher {
96
99
private final VectorEncoding vectorEncoding ;
97
100
private final float overSamplingFactor ;
98
101
private final int searchThreads ;
102
+ private final int numSearchers ;
99
103
100
104
KnnSearcher (Path indexPath , CmdLineArgs cmdLineArgs , int nProbe ) {
101
105
this .docPath = cmdLineArgs .docVectors ();
@@ -115,6 +119,7 @@ class KnnSearcher {
115
119
this .nProbe = nProbe ;
116
120
this .indexType = cmdLineArgs .indexType ();
117
121
this .searchThreads = cmdLineArgs .searchThreads ();
122
+ this .numSearchers = cmdLineArgs .numSearchers ();
118
123
}
119
124
120
125
void runSearch (KnnIndexTester .Results finalResults , boolean earlyTermination ) throws IOException {
@@ -124,7 +129,10 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
124
129
int offsetByteSize = 0 ;
125
130
try (
126
131
FileChannel input = FileChannel .open (queryPath );
127
- ExecutorService executorService = Executors .newFixedThreadPool (searchThreads , r -> new Thread (r , "KnnSearcher-Thread" ))
132
+ ExecutorService executorService = Executors .newFixedThreadPool (searchThreads , r -> new Thread (r , "KnnSearcher-Thread" ));
133
+ ExecutorService numSearchersExecutor = numSearchers > 1
134
+ ? Executors .newFixedThreadPool (numSearchers , r -> new Thread (r , "KnnSearcher-Caller" ))
135
+ : null
128
136
) {
129
137
long queryPathSizeInBytes = input .size ();
130
138
logger .info (
@@ -163,29 +171,87 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
163
171
}
164
172
}
165
173
targetReader .reset ();
174
+ final IntConsumer [] queryConsumers = new IntConsumer [numSearchers ];
175
+ if (vectorEncoding .equals (VectorEncoding .BYTE )) {
176
+ byte [][] queries = new byte [numQueryVectors ][dim ];
177
+ for (int i = 0 ; i < numQueryVectors ; i ++) {
178
+ targetReader .next (queries [i ]);
179
+ }
180
+ for (int s = 0 ; s < numSearchers ; s ++) {
181
+ queryConsumers [s ] = i -> {
182
+ try {
183
+ results [i ] = doVectorQuery (queries [i ], searcher , earlyTermination );
184
+ } catch (IOException e ) {
185
+ throw new UncheckedIOException (e );
186
+ }
187
+ };
188
+ }
189
+ } else {
190
+ float [][] queries = new float [numQueryVectors ][dim ];
191
+ for (int i = 0 ; i < numQueryVectors ; i ++) {
192
+ targetReader .next (queries [i ]);
193
+ }
194
+ for (int s = 0 ; s < numSearchers ; s ++) {
195
+ queryConsumers [s ] = i -> {
196
+ try {
197
+ results [i ] = doVectorQuery (queries [i ], searcher , earlyTermination );
198
+ } catch (IOException e ) {
199
+ throw new UncheckedIOException (e );
200
+ }
201
+ };
202
+ }
203
+ }
204
+ int [][] querySplits = new int [numSearchers ][];
205
+ int queriesPerSearcher = numQueryVectors / numSearchers ;
206
+ for (int s = 0 ; s < numSearchers ; s ++) {
207
+ int start = s * queriesPerSearcher ;
208
+ int end = (s == numSearchers - 1 ) ? numQueryVectors : (s + 1 ) * queriesPerSearcher ;
209
+ querySplits [s ] = new int [end - start ];
210
+ for (int i = start ; i < end ; i ++) {
211
+ querySplits [s ][i - start ] = i ;
212
+ }
213
+ }
214
+ targetReader .reset ();
166
215
startNS = System .nanoTime ();
167
216
KnnIndexTester .ThreadDetails startThreadDetails = new KnnIndexTester .ThreadDetails ();
168
- for (int i = 0 ; i < numQueryVectors ; i ++) {
169
- if (vectorEncoding .equals (VectorEncoding .BYTE )) {
170
- targetReader .next (targetBytes );
171
- results [i ] = doVectorQuery (targetBytes , searcher , earlyTermination );
172
- } else {
173
- targetReader .next (target );
174
- results [i ] = doVectorQuery (target , searcher , earlyTermination );
217
+ if (numSearchersExecutor != null ) {
218
+ // use multiple searchers
219
+ var futures = new ArrayList <Future <Void >>();
220
+ for (int s = 0 ; s < numSearchers ; s ++) {
221
+ int [] split = querySplits [s ];
222
+ IntConsumer queryConsumer = queryConsumers [s ];
223
+ futures .add (numSearchersExecutor .submit (() -> {
224
+ for (int j : split ) {
225
+ queryConsumer .accept (j );
226
+ }
227
+ return null ;
228
+ }));
229
+ }
230
+ for (Future <Void > future : futures ) {
231
+ try {
232
+ future .get ();
233
+ } catch (Exception e ) {
234
+ throw new RuntimeException ("Error executing searcher thread" , e );
235
+ }
236
+ }
237
+ } else {
238
+ // use a single searcher
239
+ for (int i = 0 ; i < numQueryVectors ; i ++) {
240
+ queryConsumers [0 ].accept (i );
175
241
}
176
242
}
177
243
KnnIndexTester .ThreadDetails endThreadDetails = new KnnIndexTester .ThreadDetails ();
178
244
elapsed = TimeUnit .NANOSECONDS .toMillis (System .nanoTime () - startNS );
179
245
long startCPUTimeNS = 0 ;
180
246
long endCPUTimeNS = 0 ;
181
247
for (int i = 0 ; i < startThreadDetails .threadInfos .length ; i ++) {
182
- if (startThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher-Thread " )) {
248
+ if (startThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher" )) {
183
249
startCPUTimeNS += startThreadDetails .cpuTimesNS [i ];
184
250
}
185
251
}
186
252
187
253
for (int i = 0 ; i < endThreadDetails .threadInfos .length ; i ++) {
188
- if (endThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher-Thread " )) {
254
+ if (endThreadDetails .threadInfos [i ].getThreadName ().startsWith ("KnnSearcher" )) {
189
255
endCPUTimeNS += endThreadDetails .cpuTimesNS [i ];
190
256
}
191
257
}
0 commit comments