Skip to content

Commit 7fec052

Browse files
committed
Add a direct IO option to rescore_vector for bbq_hnsw
1 parent 3e47504 commit 7fec052

File tree

6 files changed

+156
-94
lines changed

6 files changed

+156
-94
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ static Codec createCodec(CmdLineArgs args) {
9595
if (args.indexType() == IndexType.FLAT) {
9696
format = new ES818BinaryQuantizedVectorsFormat();
9797
} else {
98-
format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, null);
98+
format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM(), args.hnswEfConstruction(), 1, false, null);
9999
}
100100
} else if (args.quantizeBits() < 32) {
101101
if (args.indexType() == IndexType.FLAT) {

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,6 @@
8787
*/
8888
public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
8989

90-
public static final boolean USE_DIRECT_IO = Boolean.parseBoolean(System.getProperty("vector.rescoring.directio", "false"));
91-
9290
public static final String BINARIZED_VECTOR_COMPONENT = "BVEC";
9391
public static final String NAME = "ES818BinaryQuantizedVectorsFormat";
9492

@@ -100,17 +98,24 @@ public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
10098
static final String VECTOR_DATA_EXTENSION = "veb";
10199
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
102100

103-
private static final FlatVectorsFormat rawVectorFormat = USE_DIRECT_IO
104-
? new DirectIOLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
105-
: new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
106-
107101
private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
108102
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
109103
);
110104

105+
private final FlatVectorsFormat rawVectorFormat;
106+
111107
/** Creates a new instance with the default number of vectors per cluster. */
112108
public ES818BinaryQuantizedVectorsFormat() {
109+
this(false);
110+
}
111+
112+
/** Creates a new instance with the default number of vectors per cluster,
113+
* and whether direct IO should be used to access raw vectors. */
114+
public ES818BinaryQuantizedVectorsFormat(boolean useDirectIO) {
113115
super(NAME);
116+
rawVectorFormat = useDirectIO
117+
? new DirectIOLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
118+
: new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
114119
}
115120

116121
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ public class ES818HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat {
6262
private final int beamWidth;
6363

6464
/** The format for storing, reading, merging vectors on disk */
65-
private static final FlatVectorsFormat flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat();
65+
private final FlatVectorsFormat flatVectorsFormat;
6666

6767
private final int numMergeWorkers;
6868
private final TaskExecutor mergeExec;
6969

7070
/**
 * Constructs a format using the default graph construction parameters
 * ({@code DEFAULT_MAX_CONN}, {@code DEFAULT_BEAM_WIDTH}, {@code DEFAULT_NUM_MERGE_WORKER}),
 * with direct IO disabled and no merge executor.
 */
public ES818HnswBinaryQuantizedVectorsFormat() {
    this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, false, null);
}
7474

7575
/**
@@ -79,7 +79,18 @@ public ES818HnswBinaryQuantizedVectorsFormat() {
7979
* @param beamWidth the size of the queue maintained during graph construction.
8080
*/
8181
public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) {
82-
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
82+
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, false, null);
83+
}
84+
85+
/**
 * Constructs a format using the given graph construction parameters and direct IO choice.
 * The number of merge workers is {@code DEFAULT_NUM_MERGE_WORKER} and no merge executor is used.
 *
 * @param maxConn     the maximum number of connections to a node in the HNSW graph
 * @param beamWidth   the size of the queue maintained during graph construction.
 * @param useDirectIO whether direct IO should be used to access raw vectors
 */
public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useDirectIO) {
    this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, useDirectIO, null);
}
8495

8596
/**
@@ -92,7 +103,13 @@ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) {
92103
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
93104
* generated by this format to do the merge
94105
*/
95-
public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
106+
public ES818HnswBinaryQuantizedVectorsFormat(
107+
int maxConn,
108+
int beamWidth,
109+
int numMergeWorkers,
110+
boolean useDirectIO,
111+
ExecutorService mergeExec
112+
) {
96113
super(NAME);
97114
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
98115
throw new IllegalArgumentException(
@@ -110,6 +127,9 @@ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int num
110127
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
111128
}
112129
this.numMergeWorkers = numMergeWorkers;
130+
131+
flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat(useDirectIO);
132+
113133
if (mergeExec != null) {
114134
this.mergeExec = new TaskExecutor(mergeExec);
115135
} else {

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ private DenseVectorIndexOptions defaultIndexOptions(boolean defaultInt8Hnsw, boo
387387
return new BBQHnswIndexOptions(
388388
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
389389
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
390-
new RescoreVector(DEFAULT_OVERSAMPLE)
390+
null
391391
);
392392
} else if (defaultInt8Hnsw) {
393393
return new Int8HnswIndexOptions(
@@ -1632,9 +1632,6 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
16321632
RescoreVector rescoreVector = null;
16331633
if (hasRescoreIndexVersion(indexVersion)) {
16341634
rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion);
1635-
if (rescoreVector == null && defaultOversampleForBBQ(indexVersion)) {
1636-
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
1637-
}
16381635
}
16391636
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
16401637
return new BBQHnswIndexOptions(m, efConstruction, rescoreVector);
@@ -1656,9 +1653,6 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
16561653
RescoreVector rescoreVector = null;
16571654
if (hasRescoreIndexVersion(indexVersion)) {
16581655
rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion);
1659-
if (rescoreVector == null && defaultOversampleForBBQ(indexVersion)) {
1660-
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
1661-
}
16621656
}
16631657
MappingParser.checkNoRemainingFields(fieldName, indexOptionsMap);
16641658
return new BBQFlatIndexOptions(rescoreVector);
@@ -1693,9 +1687,6 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
16931687
}
16941688
}
16951689
RescoreVector rescoreVector = RescoreVector.fromIndexOptions(indexOptionsMap, indexVersion);
1696-
if (rescoreVector == null) {
1697-
rescoreVector = new RescoreVector(DEFAULT_OVERSAMPLE);
1698-
}
16991690
Object nProbeNode = indexOptionsMap.remove("default_n_probe");
17001691
int nProbe = -1;
17011692
if (nProbeNode != null) {
@@ -2183,7 +2174,8 @@ public BBQHnswIndexOptions(int m, int efConstruction, RescoreVector rescoreVecto
21832174
@Override
21842175
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
21852176
assert elementType == ElementType.FLOAT;
2186-
return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction);
2177+
boolean directIO = rescoreVector.useDirectIO != null && rescoreVector.useDirectIO;
2178+
return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction, directIO);
21872179
}
21882180

21892181
@Override
@@ -2342,36 +2334,44 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
23422334
}
23432335
}
23442336

2345-
public record RescoreVector(float oversample) implements ToXContentObject {
2337+
public record RescoreVector(Float oversample, Boolean useDirectIO) implements ToXContentObject {
23462338
static final String NAME = "rescore_vector";
23472339
static final String OVERSAMPLE = "oversample";
2340+
static final String DIRECT_IO = "direct_io";
23482341

23492342
static RescoreVector fromIndexOptions(Map<String, ?> indexOptionsMap, IndexVersion indexVersion) {
23502343
Object rescoreVectorNode = indexOptionsMap.remove(NAME);
23512344
if (rescoreVectorNode == null) {
23522345
return null;
23532346
}
23542347
Map<String, Object> mappedNode = XContentMapValues.nodeMapValue(rescoreVectorNode, NAME);
2348+
2349+
Float oversampleValue = null;
23552350
Object oversampleNode = mappedNode.get(OVERSAMPLE);
2356-
if (oversampleNode == null) {
2357-
throw new IllegalArgumentException("Invalid rescore_vector value. Missing required field " + OVERSAMPLE);
2358-
}
2359-
float oversampleValue = (float) XContentMapValues.nodeDoubleValue(oversampleNode);
2360-
if (oversampleValue == 0 && allowsZeroRescore(indexVersion) == false) {
2361-
throw new IllegalArgumentException("oversample must be greater than 1");
2362-
}
2363-
if (oversampleValue < 1 && oversampleValue != 0) {
2364-
throw new IllegalArgumentException("oversample must be greater than 1 or exactly 0");
2365-
} else if (oversampleValue > 10) {
2366-
throw new IllegalArgumentException("oversample must be less than or equal to 10");
2351+
if (oversampleNode != null) {
2352+
oversampleValue = (float) XContentMapValues.nodeDoubleValue(oversampleNode);
2353+
if (oversampleValue == 0 && allowsZeroRescore(indexVersion) == false) {
2354+
throw new IllegalArgumentException("oversample must be greater than 1");
2355+
}
2356+
if (oversampleValue < 1 && oversampleValue != 0) {
2357+
throw new IllegalArgumentException("oversample must be greater than 1 or exactly 0");
2358+
} else if (oversampleValue > 10) {
2359+
throw new IllegalArgumentException("oversample must be less than or equal to 10");
2360+
}
23672361
}
2368-
return new RescoreVector(oversampleValue);
2362+
2363+
Boolean directIO = (Boolean) mappedNode.get(DIRECT_IO);
2364+
2365+
return new RescoreVector(oversampleValue, directIO);
23692366
}
23702367

23712368
@Override
23722369
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
23732370
builder.startObject(NAME);
23742371
builder.field(OVERSAMPLE, oversample);
2372+
if (useDirectIO) {
2373+
builder.field(DIRECT_IO, useDirectIO);
2374+
}
23752375
builder.endObject();
23762376
return builder;
23772377
}
@@ -2710,6 +2710,10 @@ && isNotUnitVector(squaredMagnitude)) {
27102710
&& quantizedIndexOptions.rescoreVector != null) {
27112711
oversample = quantizedIndexOptions.rescoreVector.oversample;
27122712
}
2713+
if (oversample == null) {
2714+
oversample = DEFAULT_OVERSAMPLE;
2715+
}
2716+
27132717
boolean rescore = needsRescore(oversample);
27142718
if (rescore) {
27152719
// Will get k * oversample for rescoring, and get the top k
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es818;
11+
12+
import org.apache.lucene.codecs.Codec;
13+
import org.apache.lucene.misc.store.DirectIODirectory;
14+
import org.apache.lucene.store.Directory;
15+
import org.apache.lucene.store.FSDirectory;
16+
import org.apache.lucene.store.IOContext;
17+
import org.apache.lucene.store.IndexOutput;
18+
import org.apache.lucene.tests.store.MockDirectoryWrapper;
19+
import org.apache.lucene.tests.util.TestUtil;
20+
import org.elasticsearch.common.settings.Settings;
21+
import org.elasticsearch.index.IndexModule;
22+
import org.elasticsearch.index.IndexSettings;
23+
import org.elasticsearch.index.shard.ShardId;
24+
import org.elasticsearch.index.shard.ShardPath;
25+
import org.elasticsearch.index.store.FsDirectoryFactory;
26+
import org.elasticsearch.test.IndexSettingsModule;
27+
import org.junit.BeforeClass;
28+
29+
import java.io.IOException;
30+
import java.nio.file.Files;
31+
import java.nio.file.Path;
32+
import java.util.Locale;
33+
import java.util.OptionalLong;
34+
35+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
36+
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
37+
38+
public class ES818DirectIOHnswBinaryQuantizedVectorsFormatTests extends ES818HnswBinaryQuantizedVectorsFormatTests {
39+
40+
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(
41+
new ES818HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, true)
42+
);
43+
44+
@Override
45+
protected Codec getCodec() {
46+
return codec;
47+
}
48+
49+
@BeforeClass
50+
public static void checkDirectIOSupport() {
51+
Path path = createTempDir("directIOProbe");
52+
try (Directory dir = open(path); IndexOutput out = dir.createOutput("out", IOContext.DEFAULT)) {
53+
out.writeString("test");
54+
} catch (IOException e) {
55+
assumeNoException("test requires a filesystem that supports Direct IO", e);
56+
}
57+
}
58+
59+
static DirectIODirectory open(Path path) throws IOException {
60+
return new DirectIODirectory(FSDirectory.open(path)) {
61+
@Override
62+
protected boolean useDirectIO(String name, IOContext context, OptionalLong fileLength) {
63+
return true;
64+
}
65+
};
66+
}
67+
68+
@Override
69+
public void testSimpleOffHeapSize() throws IOException {
70+
var config = newIndexWriterConfig().setUseCompoundFile(false); // avoid compound files to allow directIO
71+
try (Directory dir = newFSDirectory()) {
72+
testSimpleOffHeapSizeImpl(dir, config, false);
73+
}
74+
}
75+
76+
private Directory newFSDirectory() throws IOException {
77+
Settings settings = Settings.builder()
78+
.put(IndexModule.INDEX_STORE_TYPE_SETTING.getKey(), IndexModule.Type.HYBRIDFS.name().toLowerCase(Locale.ROOT))
79+
.build();
80+
IndexSettings idxSettings = IndexSettingsModule.newIndexSettings("foo", settings);
81+
Path tempDir = createTempDir().resolve(idxSettings.getUUID()).resolve("0");
82+
Files.createDirectories(tempDir);
83+
ShardPath path = new ShardPath(false, tempDir, tempDir, new ShardId(idxSettings.getIndex(), 0));
84+
Directory dir = (new FsDirectoryFactory()).newDirectory(idxSettings, path);
85+
if (random().nextBoolean()) {
86+
dir = new MockDirectoryWrapper(random(), dir);
87+
}
88+
return dir;
89+
}
90+
}

0 commit comments

Comments
 (0)