diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java new file mode 100644 index 0000000000000..3d45e30d7abef --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java @@ -0,0 +1,250 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.simdvec.VectorScorerFactory; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; +import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; + +@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) +@Warmup(iterations = 3, time = 3) +@Measurement(iterations = 5, time = 3) +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Thread) +/** + * Benchmark that compares various float32 vector similarity function + * implementations;: scalar, lucene's panama-ized, and Elasticsearch's native. + * Run with ./gradlew -p benchmarks run --args 'Float32ScorerBenchmark' + */ +public class Float32ScorerBenchmark { + + static { + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + if (supportsHeapSegments() == false) { + final Logger LOG = LogManager.getLogger(Float32ScorerBenchmark.class); + LOG.warn("*Query targets cannot run on " + "JDK " + Runtime.version()); + } + } + + @Param({ "96", "768", "1024" }) + public int dims; + final int size = 3; // there are only two vectors to compare against + + Directory dir; + IndexInput in; + VectorScorerFactory factory; + + float[] vec1, vec2, vec3; + + UpdateableRandomVectorScorer luceneDotScorer; + UpdateableRandomVectorScorer luceneSqrScorer; + UpdateableRandomVectorScorer nativeDotScorer; + UpdateableRandomVectorScorer nativeSqrScorer; + + RandomVectorScorer luceneDotScorerQuery; + RandomVectorScorer nativeDotScorerQuery; + RandomVectorScorer luceneSqrScorerQuery; + RandomVectorScorer nativeSqrScorerQuery; + + @Setup + public void setup() throws IOException { + var optionalVectorScorerFactory = VectorScorerFactory.instance(); + if (optionalVectorScorerFactory.isEmpty()) { + String msg = "JDK=[" + + Runtime.version() + + "], os.name=[" + + System.getProperty("os.name") + + "], os.arch=[" + + System.getProperty("os.arch") + + "]"; + throw new AssertionError("Vector scorer factory not present. Cannot run the benchmark. " + msg); + } + factory = optionalVectorScorerFactory.get(); + vec1 = randomFloatArray(dims); + vec2 = randomFloatArray(dims); + vec3 = randomFloatArray(dims); + + dir = new MMapDirectory(Files.createTempDirectory("nativeFloat32Bench")); + try (IndexOutput out = dir.createOutput("vector32.data", IOContext.DEFAULT)) { + writeFloat32Vectors(out, vec1, vec2, vec3); + } + in = dir.openInput("vector32.data", IOContext.DEFAULT); + var values = vectorValues(dims, 3, in, VectorSimilarityFunction.DOT_PRODUCT); + luceneDotScorer = luceneScoreSupplier(values, VectorSimilarityFunction.DOT_PRODUCT).scorer(); + luceneDotScorer.setScoringOrdinal(0); + values = vectorValues(dims, 3, in, VectorSimilarityFunction.EUCLIDEAN); + luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(); + luceneSqrScorer.setScoringOrdinal(0); + + nativeDotScorer = factory.getFloat32VectorScorerSupplier(DOT_PRODUCT, in, values).get().scorer(); + nativeDotScorer.setScoringOrdinal(0); + nativeSqrScorer = factory.getFloat32VectorScorerSupplier(EUCLIDEAN, in, values).get().scorer(); + nativeSqrScorer.setScoringOrdinal(0); + + // setup for getFloat32VectorScorer / query vector scoring + if (supportsHeapSegments()) { + float[] queryVec = new float[dims]; + for (int i = 0; i < dims; i++) { + queryVec[i] = ThreadLocalRandom.current().nextFloat(); + } + luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec); + nativeDotScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); + luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec); + nativeSqrScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get(); + } + } + + @TearDown + public void teardown() throws IOException { + IOUtils.close(dir, in); + } + + // we score against two different ords to avoid the lastOrd cache in vector values + @Benchmark + public float dotProductLucene() throws IOException { + return luceneDotScorer.score(1) + luceneDotScorer.score(2); + } + + @Benchmark + public float dotProductNative() throws IOException { + return nativeDotScorer.score(1) + nativeDotScorer.score(2); + } + + @Benchmark + public float dotProductScalar() { + return dotProductScalarImpl(vec1, vec2) + dotProductScalarImpl(vec1, vec3); + } + + @Benchmark + public float dotProductLuceneQuery() throws IOException { + return luceneDotScorerQuery.score(1) + luceneDotScorerQuery.score(2); + } + + @Benchmark + public float dotProductNativeQuery() throws IOException { + return nativeDotScorerQuery.score(1) + nativeDotScorerQuery.score(2); + } + + // -- square distance + + @Benchmark + public float squareDistanceLucene() throws IOException { + return luceneSqrScorer.score(1) + luceneSqrScorer.score(2); + } + + @Benchmark + public float squareDistanceNative() throws IOException { + return nativeSqrScorer.score(1) + nativeSqrScorer.score(2); + } + + @Benchmark + public float squareDistanceScalar() { + return squareDistanceScalarImpl(vec1, vec2) + squareDistanceScalarImpl(vec1, vec3); + } + + @Benchmark + public float squareDistanceLuceneQuery() throws IOException { + return luceneSqrScorerQuery.score(1) + luceneSqrScorerQuery.score(2); + } + + @Benchmark + public float squareDistanceNativeQuery() throws IOException { + return nativeSqrScorerQuery.score(1) + nativeSqrScorerQuery.score(2); + } + + static boolean supportsHeapSegments() { + return Runtime.version().feature() >= 22; + } + + static float dotProductScalarImpl(float[] vec1, float[] vec2) { + float dot = 0; + for (int i = 0; i < vec1.length; i++) { + dot += vec1[i] * vec2[i]; + } + return Math.max((1 + dot) / 2, 0); + } + + static float squareDistanceScalarImpl(float[] vec1, float[] vec2) { + float dst = 0; + for (int i = 0; i < vec1.length; i++) { + float diff = vec1[i] - vec2[i]; + dst += diff * diff; + } + return 1 / (1f + dst); + } + + FloatVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + var slice = in.slice("values", 0, in.length()); + var byteSize = dims * Float.BYTES; + return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(dims, size, slice, byteSize, DefaultFlatVectorScorer.INSTANCE, sim); + } + + RandomVectorScorerSupplier luceneScoreSupplier(FloatVectorValues values, VectorSimilarityFunction sim) throws IOException { + return DefaultFlatVectorScorer.INSTANCE.getRandomVectorScorerSupplier(sim, values); + } + + RandomVectorScorer luceneScorer(FloatVectorValues values, VectorSimilarityFunction sim, float[] queryVec) throws IOException { + return DefaultFlatVectorScorer.INSTANCE.getRandomVectorScorer(sim, values, queryVec); + } + + static void writeFloat32Vectors(IndexOutput out, float[]... vectors) throws IOException { + var buffer = ByteBuffer.allocate(vectors[0].length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (var v : vectors) { + buffer.asFloatBuffer().put(v); + out.writeBytes(buffer.array(), buffer.array().length); + } + } + + static float[] randomFloatArray(int length) { + var random = ThreadLocalRandom.current(); + float[] fa = new float[length]; + for (int i = 0; i < length; i++) { + fa[i] = random.nextFloat(); + } + return fa; + } +} diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java new file mode 100644 index 0000000000000..cd0088b9f8a13 --- /dev/null +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.Constants; +import org.elasticsearch.test.ESTestCase; +import org.junit.BeforeClass; +import org.openjdk.jmh.annotations.Param; + +import java.util.Arrays; + +public class Float32ScorerBenchmarkTests extends ESTestCase { + + final double delta = 1e-3; + final int dims; + + public Float32ScorerBenchmarkTests(int dims) { + this.dims = dims; + } + + @BeforeClass + public static void skipWindows() { + assumeFalse("doesn't work on windows yet", Constants.WINDOWS); + } + + public void testDotProduct() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Float32ScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.dotProductScalar(); + assertEquals(expected, bench.dotProductLucene(), delta); + assertEquals(expected, bench.dotProductNative(), delta); + + if (Float32ScorerBenchmark.supportsHeapSegments()) { + expected = bench.dotProductLuceneQuery(); + assertEquals(expected, bench.dotProductNativeQuery(), delta); + } + } finally { + bench.teardown(); + } + } + } + + public void testSquareDistance() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Float32ScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.squareDistanceScalar(); + assertEquals(expected, bench.squareDistanceLucene(), delta); + assertEquals(expected, bench.squareDistanceNative(), delta); + + if (Float32ScorerBenchmark.supportsHeapSegments()) { + expected = bench.squareDistanceLuceneQuery(); + assertEquals(expected, bench.squareDistanceNativeQuery(), delta); + } + } finally { + bench.teardown(); + } + } + } + + @ParametersFactory + public static Iterable parametersFactory() { + try { + var params = Float32ScorerBenchmark.class.getField("dims").getAnnotationsByType(Param.class)[0].value(); + return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator(); + } catch (NoSuchFieldException e) { + throw new AssertionError(e); + } + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java index 4ed60b2f5e8b2..cec456d70afb7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java @@ -9,6 +9,7 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; @@ -53,4 +54,32 @@ Optional getInt7SQVectorScorerSupplier( * @return an optional containing the vector scorer, or empty */ Optional getInt7SQVectorScorer(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector); + + /** + * Returns an optional containing a float32 vector scorer for the given + * parameters, or an empty optional if a scorer is not supported. + * + * @param similarityType the similarity type + * @param input the index input containing the vector data; + * offset of the first vector is 0, + * the length must be maxOrd * dims * Float#BYTES. + * @param values the random access vector values + * @return an optional containing the vector scorer, or empty + */ + Optional getFloat32VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ); + + /** + * Returns an optional containing a float32 vector scorer for the given + * parameters, or an empty optional if a scorer is not supported. + * + * @param sim the similarity type + * @param values the random access vector values + * @param queryVector the query vector + * @return an optional containing the vector scorer, or empty + */ + Optional getFloat32VectorScorer(VectorSimilarityFunction sim, FloatVectorValues values, float[] queryVector); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index 6248902c32e7a..f43bc1062e6c7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -9,6 +9,7 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; @@ -39,4 +40,22 @@ public Optional getInt7SQVectorScorer( ) { throw new UnsupportedOperationException("should not reach here"); } + + @Override + public Optional getFloat32VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ) { + throw new UnsupportedOperationException("should not reach here"); + } + + @Override + public Optional getFloat32VectorScorer( + VectorSimilarityFunction sim, + FloatVectorValues values, + float[] queryVector + ) { + throw new UnsupportedOperationException("should not reach here"); + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index a863d9e3448ca..7c45ac04c59b8 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -9,18 +9,17 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.elasticsearch.nativeaccess.NativeAccess; +import org.elasticsearch.simdvec.internal.Float32VectorScorer; +import org.elasticsearch.simdvec.internal.Float32VectorScorerSupplier; import org.elasticsearch.simdvec.internal.Int7SQVectorScorer; -import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.DotProductSupplier; -import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.EuclideanSupplier; -import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.MaxInnerProductSupplier; +import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier; import java.util.Optional; @@ -41,17 +40,7 @@ public Optional getInt7SQVectorScorerSupplier( QuantizedByteVectorValues values, float scoreCorrectionConstant ) { - input = FilterIndexInput.unwrapOnlyTest(input); - if (input instanceof MemorySegmentAccessInput == false) { - return Optional.empty(); - } - MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; - checkInvariants(values.size(), values.dimension(), input); - return switch (similarityType) { - case COSINE, DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values, scoreCorrectionConstant)); - case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values, scoreCorrectionConstant)); - case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values, scoreCorrectionConstant)); - }; + return Int7SQVectorScorerSupplier.create(similarityType, input, values, scoreCorrectionConstant); } @Override @@ -63,9 +52,28 @@ public Optional getInt7SQVectorScorer( return Int7SQVectorScorer.create(sim, values, queryVector); } - static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { - if (input.length() < (long) vectorByteLength * maxOrd) { - throw new IllegalArgumentException("input length is less than expected vector data"); - } + // -- floats + + @Override + public Optional getFloat32VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ) { + return Float32VectorScorerSupplier.create(similarityType, input, values); + } + + @Override + public Optional getFloat32VectorScorer( + VectorSimilarityFunction sim, + FloatVectorValues values, + float[] queryVector + ) { + return Float32VectorScorer.create(sim, values, queryVector); + } + + @Override + public String toString() { + return "VectorScorerFactoryImpl"; } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java new file mode 100644 index 0000000000000..89a0ea0df8eeb --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.util.Optional; + +public class Float32VectorScorer { + + // Unconditionally returns an empty optional on <= JDK 21, since the scorer is only supported on JDK 22+ + public static Optional create(VectorSimilarityFunction sim, FloatVectorValues values, float[] queryVector) { + return Optional.empty(); + } + + private Float32VectorScorer() {} +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorerSupplier.java new file mode 100644 index 0000000000000..74f7b43ba41c8 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorerSupplier.java @@ -0,0 +1,209 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; + +public abstract sealed class Float32VectorScorerSupplier implements RandomVectorScorerSupplier { + + public static Optional create( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.dimension(), input); + return switch (similarityType) { + case COSINE -> Optional.of(new CosineSupplier(msInput, values)); + case DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values)); + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values)); + }; + } + + final int dims; + final int vectorByteSize; + final int maxOrd; + final MemorySegmentAccessInput input; + final FloatVectorValues values; // to support ordToDoc/getAcceptOrds + final VectorSimilarityFunction fallbackScorer; + + protected Float32VectorScorerSupplier( + MemorySegmentAccessInput input, + FloatVectorValues values, + VectorSimilarityFunction fallbackScorer + ) { + this.input = input; + this.values = values; + this.dims = values.dimension(); + this.vectorByteSize = values.getVectorByteLength(); + this.maxOrd = values.size(); + this.fallbackScorer = fallbackScorer; + } + + protected final void checkOrdinal(int ord) { + if (ord < 0 || ord > maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { + long firstByteOffset = (long) firstOrd * vectorByteSize; + long secondByteOffset = (long) secondOrd * vectorByteSize; + + MemorySegment firstSeg = input.segmentSliceOrNull(firstByteOffset, vectorByteSize); + if (firstSeg == null) { + return fallbackScore(firstByteOffset, secondByteOffset); + } + + MemorySegment secondSeg = input.segmentSliceOrNull(secondByteOffset, vectorByteSize); + if (secondSeg == null) { + return fallbackScore(firstByteOffset, secondByteOffset); + } + + return scoreFromSegments(firstSeg, secondSeg); + } + + abstract float scoreFromSegments(MemorySegment a, MemorySegment b); + + protected final float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException { + float[] a = new float[dims]; + readFloats(a, firstByteOffset, dims); + + float[] b = new float[dims]; + readFloats(b, secondByteOffset, dims); + + return fallbackScorer.compare(a, b); + } + + final void readFloats(float[] floats, long byteOffset, int len) throws IOException { + for (int i = 0; i < len; i++) { + floats[i] = Float.intBitsToFloat(input.readInt(byteOffset)); + byteOffset += Float.BYTES; + } + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int ord = -1; + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + return scoreFromOrds(ord, node); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + checkOrdinal(node); + this.ord = node; + } + }; + } + + public static final class CosineSupplier extends Float32VectorScorerSupplier { + + public CosineSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.COSINE); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float v = Similarities.cosineFloat32(a, b, dims); + return Math.max((1 + v) / 2, 0); + } + + @Override + public CosineSupplier copy() { + return new CosineSupplier(input.clone(), values); + } + } + + public static final class EuclideanSupplier extends Float32VectorScorerSupplier { + + public EuclideanSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.EUCLIDEAN); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float v = Similarities.squareDistanceFloat32(a, b, dims); + return 1 / (1 + v); + } + + @Override + public EuclideanSupplier copy() { + return new EuclideanSupplier(input.clone(), values); + } + } + + public static final class DotProductSupplier extends Float32VectorScorerSupplier { + + public DotProductSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.DOT_PRODUCT); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float dotProduct = Similarities.dotProductFloat32(a, b, dims); + return Math.max((1 + dotProduct) / 2, 0); + } + + @Override + public DotProductSupplier copy() { + return new DotProductSupplier(input.clone(), values); + } + } + + public static final class MaxInnerProductSupplier extends Float32VectorScorerSupplier { + + public MaxInnerProductSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float v = Similarities.dotProductFloat32(a, b, dims); + return VectorUtil.scaleMaxInnerProductScore(v); + } + + @Override + public MaxInnerProductSupplier copy() { + return new MaxInnerProductSupplier(input.clone(), values); + } + } + + static boolean checkIndex(long index, long length) { + return index >= 0 && index < length; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java index 19f33ba1c71f7..83ee1532472c2 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java @@ -9,14 +9,18 @@ package org.elasticsearch.simdvec.internal; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; +import org.elasticsearch.simdvec.VectorSimilarityType; import java.io.IOException; import java.lang.foreign.MemorySegment; +import java.util.Optional; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; @@ -27,6 +31,25 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS static final byte BITS = 7; + public static Optional create( + VectorSimilarityType similarityType, + IndexInput input, + QuantizedByteVectorValues values, + float scoreCorrectionConstant + ) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.dimension(), input); + return switch (similarityType) { + case COSINE, DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values, scoreCorrectionConstant)); + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values, scoreCorrectionConstant)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values, scoreCorrectionConstant)); + }; + } + final int dims; final int maxOrd; final float scoreCorrectionConstant; @@ -172,4 +195,10 @@ public MaxInnerProductSupplier copy() { static boolean checkIndex(long index, long length) { return index >= 0 && index < length; } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java index 482bbc8d8cabe..565b92892dbb3 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java @@ -23,6 +23,9 @@ public class Similarities { static final MethodHandle DOT_PRODUCT_7U = DISTANCE_FUNCS.dotProductHandle7u(); static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u(); + static final MethodHandle COS_PRODUCT_FLOAT32 = DISTANCE_FUNCS.cosineHandleFloat32(); + static final MethodHandle DOT_PRODUCT_FLOAT32 = DISTANCE_FUNCS.dotProductHandleFloat32(); + static final MethodHandle SQR_PRODUCT_FLOAT32 = DISTANCE_FUNCS.squareDistanceHandleFloat32(); static int dotProduct7u(MemorySegment a, MemorySegment b, int length) { try { @@ -51,4 +54,46 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { } } } + + static float cosineFloat32(MemorySegment a, MemorySegment b, int elementCount) { + try { + return (float) COS_PRODUCT_FLOAT32.invokeExact(a, b, elementCount); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + static float dotProductFloat32(MemorySegment a, MemorySegment b, int elementCount) { + try { + return (float) DOT_PRODUCT_FLOAT32.invokeExact(a, b, elementCount); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + static float squareDistanceFloat32(MemorySegment a, MemorySegment b, int elementCount) { + try { + return (float) SQR_PRODUCT_FLOAT32.invokeExact(a, b, elementCount); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java new file mode 100644 index 0000000000000..3dd1a179dd6dd --- /dev/null +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; + +public abstract sealed class Float32VectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + + final int dims; + final int vectorByteSize; + final MemorySegmentAccessInput input; + final MemorySegment query; + byte[] scratch; + + /** Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is returned. */ + public static Optional create(VectorSimilarityFunction sim, FloatVectorValues values, float[] queryVector) { + checkDimensions(queryVector.length, values.dimension()); + IndexInput input = null; + if (values instanceof HasIndexSlice hasIndexSlice) { + input = hasIndexSlice.getSlice(); + } + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.getVectorByteLength(), input); + + return switch (sim) { + case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector)); + case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector)); + case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(msInput, values, queryVector)); + }; + } + + Float32VectorScorer(MemorySegmentAccessInput input, FloatVectorValues values, float[] queryVector) { + super(values); + this.input = input; + assert queryVector.length * Float.BYTES == values.getVectorByteLength() && queryVector.length == values.dimension(); + this.vectorByteSize = values.getVectorByteLength(); + this.dims = values.dimension(); + this.query = MemorySegment.ofArray(queryVector); + } + + final MemorySegment getSegment(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * vectorByteSize; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch == null) { + scratch = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + public static final class CosineScorer extends Float32VectorScorer { + public CosineScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = Similarities.cosineFloat32(query, getSegment(node), dims); + return Math.max((1 + dotProduct) / 2, 0); + } + } + + public static final class DotProductScorer extends Float32VectorScorer { + public DotProductScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = Similarities.dotProductFloat32(query, getSegment(node), dims); + return Math.max((1 + dotProduct) / 2, 0); + } + } + + public static final class EuclideanScorer extends Float32VectorScorer { + public EuclideanScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float v = Similarities.squareDistanceFloat32(query, getSegment(node), dims); + return 1 / (1 + v); + } + } + + public static final class MaxInnerProductScorer extends Float32VectorScorer { + public MaxInnerProductScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float v = Similarities.dotProductFloat32(query, getSegment(node), dims); + return VectorUtil.scaleMaxInnerProductScore(v); + } + } + + static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Float32VectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Float32VectorScorerFactoryTests.java new file mode 100644 index 0000000000000..27740e4ed43f0 --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Float32VectorScorerFactoryTests.java @@ -0,0 +1,409 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec; + +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import java.util.stream.IntStream; + +import static org.elasticsearch.simdvec.VectorSimilarityType.COSINE; +import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; +import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; +import static org.elasticsearch.simdvec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; + +public class Float32VectorScorerFactoryTests extends AbstractVectorTestCase { + + static final double BASE_DELTA = 1e-5; + + // Tests that the provider instance is present or not on expected platforms/architectures + public void testSupport() { + supported(); + } + + public void testSimple() throws IOException { + testSimpleImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testSimpleMaxChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(4, 16); + logger.info("maxChunkSize=" + maxChunkSize); + testSimpleImpl(maxChunkSize); + } + + void testSimpleImpl(long maxChunkSize) throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) { + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + for (int dims : List.of(31, 32, 33)) { + // dimensions that cross the scalar / native boundary (stride) + float[] vector1 = new float[dims]; + float[] vector2 = new float[dims]; + String fileName = "testSimpleImpl-" + sim + "-" + dims + ".vex"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < dims; i++) { + vector1[i] = (float) i; + vector2[i] = (float) (dims - i); + } + writeFloat32Vectors(out, vector1, vector2); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 2, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vector1, vector2); + + var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(); + luceneSupplier.setScoringOrdinal(0); + assertEquals(expected, luceneSupplier.score(1), BASE_DELTA); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertEquals(expected, scorer.score(1), BASE_DELTA); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vector1).get(); + assertEquals(expected, qScorer.score(1), BASE_DELTA); + } + } + } + } + } + } + + public void testNonNegative() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testNonNegativeDotProduct"), MMapDirectory.DEFAULT_MAX_CHUNK_SIZE)) { + float[] vec1 = new float[32]; + float[] vec2 = new float[32]; + String fileName = "testNonNegativeDotProduct-32"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + vec1[0] = -2.0f; // values to trigger a negative dot product + vec2[0] = 1.0f; + writeFloat32Vectors(out, vec1, vec2); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(32, 2, in, VectorSimilarityType.of(DOT_PRODUCT)); + // dot product + float expected = 0.0f; + assertThat(luceneFloat32Score(DOT_PRODUCT, vec1, vec2), equalTo(expected)); + var supplier = factory.getFloat32VectorScorerSupplier(DOT_PRODUCT, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + // max inner product + expected = luceneFloat32Score(MAXIMUM_INNER_PRODUCT, vec1, vec2); + supplier = factory.getFloat32VectorScorerSupplier(MAXIMUM_INNER_PRODUCT, in, values).get(); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + assertThat(scorer.score(1), equalTo(expected)); + // cosine + expected = 0f; + assertThat(luceneFloat32Score(COSINE, vec1, vec2), equalTo(expected)); + supplier = factory.getFloat32VectorScorerSupplier(COSINE, in, values).get(); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + // euclidean + expected = luceneFloat32Score(EUCLIDEAN, vec1, vec2); + supplier = factory.getFloat32VectorScorerSupplier(EUCLIDEAN, in, values).get(); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + } + } + } + + public void testRandom() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomWithChunkSize(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testRandomMaxChunkSizeSmall() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + testRandomWithChunkSize(maxChunkSize); + } + + void testRandomWithChunkSize(long maxChunkSize) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final float[][] vectors = IntStream.range(0, size).mapToObj(i -> randomVector(dims)).toArray(float[][]::new); + final double delta = BASE_DELTA * dims; // scale the delta with the size + + String fileName = "testRandom-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + writeFloat32Vectors(out, vectors); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vectors[idx0], vectors[idx1]); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertEquals(expected, scorer.score(idx1), delta); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get(); + assertEquals(expected, qScorer.score(idx1), delta); + } + } + } + } + } + } + + public void testRandomSlice() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSliceImpl(30, 64, 1); + } + + void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) { + for (int times = 0; times < TIMES; times++) { + final int size = randomIntBetween(2, 100); + final float[][] vectors = IntStream.range(0, size).mapToObj(i -> randomVector(dims)).toArray(float[][]::new); + final double delta = BASE_DELTA * dims; // scale the delta with the size + + String fileName = "testRandomSliceImpl-" + times + "-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + byte[] ba = new byte[initialPadding]; + out.writeBytes(ba, 0, ba.length); + writeFloat32Vectors(out, vectors); + } + try ( + var outter = dir.openInput(fileName, IOContext.DEFAULT); + var in = outter.slice("slice", initialPadding, outter.length() - initialPadding) + ) { + for (int itrs = 0; itrs < TIMES / 10; itrs++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vectors[idx0], vectors[idx1]); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertEquals(expected, scorer.score(idx1), delta); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get(); + assertEquals(expected, qScorer.score(idx1), delta); + } + } + } + } + } + } + } + + // Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow + @Nightly + public void testLarge() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) { + final int dims = 4096; + final int size = 262144; + assert (long) dims * Float.BYTES * size > Integer.MAX_VALUE; + final double delta = BASE_DELTA * dims; // scale the delta with the size + + String fileName = "testLarge-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + writeFloat32Vectors(out, vector(i, dims)); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vector(idx0, dims), vector(idx1, dims)); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertEquals(expected, scorer.score(idx1), delta); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vector(idx0, dims)).get(); + assertEquals(expected, qScorer.score(idx1), delta); + } + } + } + } + } + } + + public void testRace() throws Exception { + testRaceImpl(COSINE); + testRaceImpl(DOT_PRODUCT); + testRaceImpl(EUCLIDEAN); + testRaceImpl(MAXIMUM_INNER_PRODUCT); + } + + // Tests that copies in threads do not interfere with each other + void testRaceImpl(VectorSimilarityType sim) throws Exception { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + final long maxChunkSize = 32; + final int dims = 34; // dimensions that are larger than the chunk size, to force fallback + float[] vec1 = new float[dims]; + float[] vec2 = new float[dims]; + IntStream.range(0, dims).forEach(i -> vec1[i] = 1); + IntStream.range(0, dims).forEach(i -> vec2[i] = 2); + try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) { + String fileName = "testRace-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + writeFloat32Vectors(out, vec1, vec1, vec2, vec2); + } + var expectedScore1 = luceneFloat32Score(sim, vec1, vec1); + var expectedScore2 = luceneFloat32Score(sim, vec2, vec2); + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim)); + var scoreSupplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var tasks = List.>>of( + new ScoreCallable(scoreSupplier.copy().scorer(), 0, 1, expectedScore1), + new ScoreCallable(scoreSupplier.copy().scorer(), 2, 3, expectedScore2) + ); + var executor = Executors.newFixedThreadPool(2); + var results = executor.invokeAll(tasks); + executor.shutdown(); + assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS)); + assertThat(results.stream().filter(Predicate.not(Future::isDone)).count(), equalTo(0L)); + for (var res : results) { + assertThat("Unexpected exception" + res.get(), res.get(), isEmpty()); + } + } + } + } + + static class ScoreCallable implements Callable> { + + final UpdateableRandomVectorScorer scorer; + final int ord; + final float expectedScore; + + ScoreCallable(UpdateableRandomVectorScorer scorer, int queryOrd, int ord, float expectedScore) { + try { + this.scorer = scorer; + this.scorer.setScoringOrdinal(queryOrd); + this.ord = ord; + this.expectedScore = expectedScore; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Optional call() { + try { + for (int i = 0; i < 100; i++) { + assertThat(scorer.score(ord), equalTo(expectedScore)); + } + } catch (Throwable t) { + return Optional.of(t); + } + return Optional.empty(); + } + } + + FloatVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + var slice = in.slice("values", 0, in.length()); + var byteSize = dims * Float.BYTES; + return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(dims, size, slice, byteSize, DefaultFlatVectorScorer.INSTANCE, sim); + } + + /** Computes the score using the Lucene implementation. */ + static float luceneFloat32Score(VectorSimilarityType similarityFunc, float[] a, float[] b) { + return VectorSimilarityType.of(similarityFunc).compare(a, b); + } + + RandomVectorScorerSupplier luceneScoreSupplier(KnnVectorValues values, VectorSimilarityFunction sim) throws IOException { + return DefaultFlatVectorScorer.INSTANCE.getRandomVectorScorerSupplier(sim, values); + } + + static float[] randomVector(int dim) { + float[] v = new float[dim]; + Random random = random(); + for (int i = 0; i < dim; i++) { + v[i] = random.nextFloat(); + } + return v; + } + + // creates the vector based on the given ordinal, which is reproducible given the ord and dims + static float[] vector(int ord, int dims) { + var random = new Random(Objects.hash(ord, dims)); + float[] fa = new float[dims]; + for (int i = 0; i < dims; i++) { + fa[i] = random.nextFloat(); + } + return fa; + } + + static void writeFloat32Vectors(IndexOutput out, float[]... vectors) throws IOException { + var buffer = ByteBuffer.allocate(vectors[0].length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (var v : vectors) { + buffer.asFloatBuffer().put(v); + out.writeBytes(buffer.array(), buffer.array().length); + } + } + + static final int TIMES = 100; // a loop iteration times +} diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index ef3d5da1c9531..dc16d09e309bf 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -458,6 +458,7 @@ org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es819.ES819HnswVectorsFormat, org.elasticsearch.index.codec.vectors.IVFVectorsFormat; provides org.apache.lucene.codecs.Codec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java index 29f62b64764a9..206afbc857941 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java @@ -13,7 +13,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; @@ -41,7 +40,7 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat { static final String NAME = "ES813FlatVectorFormat"; - private static final FlatVectorsFormat format = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE); + private static final FlatVectorsFormat format = new Lucene99FlatVectorsFormat(ESFlatVectorsScorer.create()); /** * Sole constructor diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ESFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ESFlatVectorsScorer.java new file mode 100644 index 0000000000000..6c82d75348ffd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ESFlatVectorsScorer.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.simdvec.VectorScorerFactory; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; +import java.util.Objects; + +/** + * Flat vectors scorer that wraps a given delegate. This scorer first checks the Elasticsearch + * simdvec scorer factory for an optimized implementation before falling back to the delegate + * if one does not exist. + * + *

The current implementation checks for optimized simdvec float32 scorers. + */ +public final class ESFlatVectorsScorer implements FlatVectorsScorer { + + static final VectorScorerFactory VECTOR_SCORER_FACTORY = VectorScorerFactory.instance().orElse(null); + + final FlatVectorsScorer delegate; + + /** Creates a FlatVectorsScorer delegating to Lucene's default scorer {@link FlatVectorScorerUtil#getLucene99FlatVectorsScorer()} + * if the platform does not have an optimized scorer factory. + */ + public static FlatVectorsScorer create() { + return create(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + } + + /** Creates a FlatVectorsScorer. Returns the delegate if the platform does not have an optimized scorer factory. */ + static FlatVectorsScorer create(FlatVectorsScorer delegate) { + Objects.requireNonNull(delegate); + if (VECTOR_SCORER_FACTORY == null) { + return delegate; + } + return new ESFlatVectorsScorer(delegate); + } + + private ESFlatVectorsScorer(FlatVectorsScorer delegate) { + this.delegate = delegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, KnnVectorValues values) + throws IOException { + assert VECTOR_SCORER_FACTORY != null; + if (values instanceof FloatVectorValues fValues && fValues instanceof HasIndexSlice sliceable) { + if (sliceable.getSlice() != null) { + var scorer = VECTOR_SCORER_FACTORY.getFloat32VectorScorerSupplier( + VectorSimilarityType.of(sim), + sliceable.getSlice(), + fValues + ); + if (scorer.isPresent()) { + return scorer.get(); + } + } + } + return delegate.getRandomVectorScorerSupplier(sim, values); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, float[] query) + throws IOException { + assert VECTOR_SCORER_FACTORY != null; + if (values instanceof FloatVectorValues fValues && fValues instanceof HasIndexSlice sliceable) { + if (sliceable.getSlice() != null) { + var scorer = VECTOR_SCORER_FACTORY.getFloat32VectorScorer(sim, fValues, query); + if (scorer.isPresent()) { + return scorer.get(); + } + } + } + return delegate.getRandomVectorScorer(sim, values, query); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, byte[] query) throws IOException { + assert VECTOR_SCORER_FACTORY != null; + return delegate.getRandomVectorScorer(sim, values, query); + } + + @Override + public String toString() { + return "ESFlatVectorsScorer(" + "delegate=" + delegate + ", factory=" + VECTOR_SCORER_FACTORY + ')'; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java new file mode 100644 index 0000000000000..67945567ac4b9 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java @@ -0,0 +1,87 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es819; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.ESFlatVectorsScorer; + +import java.io.IOException; + +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +// Minimal copy of Lucene99HnswVectorsFormat in order to provide an optimized scorer, +// which returns identical scores to that of the default flat vector scorer. +public class ES819HnswVectorsFormat extends KnnVectorsFormat { + + static final String NAME = "ES819HnswVectorsFormat"; + + static final int MAXIMUM_MAX_CONN = Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + static final int MAXIMUM_BEAM_WIDTH = Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; + + static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(ESFlatVectorsScorer.create()); + + private final int maxConn; + private final int beamWidth; + + public ES819HnswVectorsFormat() { + this(Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + } + + public ES819HnswVectorsFormat(int maxConn, int beamWidth) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public String toString() { + return "ES819HnswVectorsFormat(name=ES819HnswVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 4d1c4fc41526c..4790221018849 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -57,6 +57,7 @@ import org.elasticsearch.index.codec.vectors.IVFVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es819.ES819HnswVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -2119,7 +2120,7 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { if (elementType == ElementType.BIT) { return new ES815HnswBitVectorsFormat(m, efConstruction); } - return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null); + return new ES819HnswVectorsFormat(m, efConstruction); } @Override diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 14e68029abc3b..b79f746f94d95 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -7,4 +7,5 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es819.ES819HnswVectorsFormat org.elasticsearch.index.codec.vectors.IVFVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormatTests.java new file mode 100644 index 0000000000000..78babb624dbc7 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormatTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es819; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.nativeaccess.NativeAccess; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES819HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final boolean optimizedScorer = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES819HnswVectorsFormat()); + + @Override + protected Codec getCodec() { + return codec; + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES819HnswVectorsFormat(10, 20); + } + }; + var expectedScorer = optimizedScorer + ? "ESFlatVectorsScorer(delegate=%luceneScorer%, factory=VectorScorerFactoryImpl)" + : "%luceneScorer%"; + String expectedPattern = "ES819HnswVectorsFormat(name=ES819HnswVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%scorer%))".replace("%scorer%", expectedScorer); + + var defaultScorer = expectedPattern.replace("%luceneScorer%", "DefaultFlatVectorScorer()"); + var memSegScorer = expectedPattern.replace("%luceneScorer%", "Lucene99MemorySegmentFlatVectorsScorer()"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index f79c14c831f86..7a63d4ea1b1d0 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -46,6 +46,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.search.lookup.Source; import org.elasticsearch.search.lookup.SourceProvider; import org.elasticsearch.search.vectors.VectorData; @@ -2628,6 +2629,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { ); } + static final boolean optimizedScorer = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + public void testKnnVectorsFormat() throws IOException { final int m = randomIntBetween(1, DEFAULT_MAX_CONN + 10); final int efConstruction = randomIntBetween(1, DEFAULT_BEAM_WIDTH + 10); @@ -2661,11 +2664,14 @@ public void testKnnVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=" + var expectedScorer = optimizedScorer + ? "ESFlatVectorsScorer(delegate=DefaultFlatVectorScorer(), factory=VectorScorerFactoryImpl)" + : "DefaultFlatVectorScorer()"; + String expectedString = "ES819HnswVectorsFormat(name=ES819HnswVectorsFormat, maxConn=" + (setM ? m : DEFAULT_MAX_CONN) + ", beamWidth=" + (setEfConstruction ? efConstruction : DEFAULT_BEAM_WIDTH) - + ", flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())" + + ", flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%scorer%)".replace("%scorer%", expectedScorer) + ")"; assertEquals(expectedString, knnVectorsFormat.toString()); }