From 4682213cb735e62381759167b3aca48202fa72a7 Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Fri, 27 Jun 2025 15:11:05 +0100 Subject: [PATCH 1/6] Add optimized native Neon, AVX2, and AVX 512 float32 vector operations. --- .../vector/Float32ScorerBenchmark.java | 238 ++++++++++ ...nchmark.java => Int7uScorerBenchmark.java} | 44 +- .../vector/JDKVectorFloat32Benchmark.java | 220 ++++++++++ .../vector/JDKVectorInt7uBenchmark.java | 2 +- .../vector/Float32ScorerBenchmarkTests.java | 80 ++++ .../vector/Int7uScorerBenchmarkTests.java | 80 ++++ .../JDKVectorFloat32BenchmarkTests.java | 139 ++++++ libs/native/libraries/build.gradle | 2 +- .../VectorSimilarityFunctions.java | 30 ++ .../nativeaccess/jdk/JdkVectorLibrary.java | 127 +++++- .../VectorSimilarityFunctionsTests.java | 43 +- .../jdk/JDKVectorLibraryFloat32Tests.java | 225 ++++++++++ ...s.java => JDKVectorLibraryInt7uTests.java} | 38 +- libs/simdvec/native/src/vec/c/aarch64/vec.c | 167 +++++++ libs/simdvec/native/src/vec/c/amd64/vec.c | 159 +++++++ libs/simdvec/native/src/vec/c/amd64/vec_2.cpp | 146 +++++++ libs/simdvec/native/src/vec/headers/vec.h | 7 + .../simdvec/VectorScorerFactory.java | 29 ++ .../simdvec/VectorScorerFactoryImpl.java | 19 + .../simdvec/VectorScorerFactoryImpl.java | 48 +- .../simdvec/internal/Float32VectorScorer.java | 26 ++ .../internal/Float32VectorScorerSupplier.java | 209 +++++++++ .../internal/Int7SQVectorScorerSupplier.java | 29 ++ .../simdvec/internal/Similarities.java | 45 ++ .../simdvec/internal/Float32VectorScorer.java | 150 +++++++ .../simdvec/AbstractVectorTestCase.java | 15 +- .../Float32VectorScorerFactoryTests.java | 409 ++++++++++++++++++ ...va => Int7SQVectorScorerFactoryTests.java} | 23 +- server/src/main/java/module-info.java | 1 + .../codec/vectors/ES813FlatVectorFormat.java | 2 +- .../codec/vectors/ESFlatVectorsScorer.java | 103 +++++ .../vectors/es819/ES819HnswVectorsFormat.java | 88 ++++ .../vectors/DenseVectorFieldMapper.java | 3 +- .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + .../es819/ES819HnswVectorsFormatTests.java | 56 +++ 35 files changed, 2890 insertions(+), 113 deletions(-) create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java rename benchmarks/src/main/java/org/elasticsearch/benchmark/vector/{VectorScorerBenchmark.java => Int7uScorerBenchmark.java} (86%) create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java create mode 100644 benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java create mode 100644 benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmarkTests.java create mode 100644 benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java create mode 100644 libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java rename libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/{JDKVectorLibraryTests.java => JDKVectorLibraryInt7uTests.java} (85%) create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorerSupplier.java create mode 100644 libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java create mode 100644 libs/simdvec/src/test/java/org/elasticsearch/simdvec/Float32VectorScorerFactoryTests.java rename libs/simdvec/src/test/java/org/elasticsearch/simdvec/{VectorScorerFactoryTests.java => Int7SQVectorScorerFactoryTests.java} (96%) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/ESFlatVectorsScorer.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormatTests.java diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java new file mode 100644 index 0000000000000..dbde4e92492a1 --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java @@ -0,0 +1,238 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.simdvec.VectorScorerFactory; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; +import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; + +@Fork(value = 1, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) +@Warmup(iterations = 3, time = 3) +@Measurement(iterations = 5, time = 3) +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Thread) +/** + * Benchmark that compares various float32 vector similarity function + * implementations;: scalar, lucene's panama-ized, and Elasticsearch's native. + * Run with ./gradlew -p benchmarks run --args 'Float32ScorerBenchmark' + */ +public class Float32ScorerBenchmark { + + static { + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Param({ "96", "768", "1024" }) + public int dims; + final int size = 3; // there are only two vectors to compare against + + Directory dir; + IndexInput in; + VectorScorerFactory factory; + + float[] vec1, vec2, vec3; + + UpdateableRandomVectorScorer luceneDotScorer; + UpdateableRandomVectorScorer luceneSqrScorer; + UpdateableRandomVectorScorer nativeDotScorer; + UpdateableRandomVectorScorer nativeSqrScorer; + + RandomVectorScorer luceneDotScorerQuery; + RandomVectorScorer nativeDotScorerQuery; + RandomVectorScorer luceneSqrScorerQuery; + RandomVectorScorer nativeSqrScorerQuery; + + @Setup + public void setup() throws IOException { + var optionalVectorScorerFactory = VectorScorerFactory.instance(); + if (optionalVectorScorerFactory.isEmpty()) { + String msg = "JDK=[" + + Runtime.version() + + "], os.name=[" + + System.getProperty("os.name") + + "], os.arch=[" + + System.getProperty("os.arch") + + "]"; + throw new AssertionError("Vector scorer factory not present. Cannot run the benchmark. " + msg); + } + factory = optionalVectorScorerFactory.get(); + vec1 = randomFloatArray(dims); + vec2 = randomFloatArray(dims); + vec3 = randomFloatArray(dims); + + dir = new MMapDirectory(Files.createTempDirectory("nativeFloat32Bench")); + try (IndexOutput out = dir.createOutput("vector32.data", IOContext.DEFAULT)) { + writeFloat32Vectors(out, vec1, vec2, vec3); + } + in = dir.openInput("vector32.data", IOContext.DEFAULT); + var values = vectorValues(dims, 3, in, VectorSimilarityFunction.DOT_PRODUCT); + luceneDotScorer = luceneScoreSupplier(values, VectorSimilarityFunction.DOT_PRODUCT).scorer(); + luceneDotScorer.setScoringOrdinal(0); + values = vectorValues(dims, 3, in, VectorSimilarityFunction.EUCLIDEAN); + luceneSqrScorer = luceneScoreSupplier(values, VectorSimilarityFunction.EUCLIDEAN).scorer(); + luceneSqrScorer.setScoringOrdinal(0); + + nativeDotScorer = factory.getFloat32VectorScorerSupplier(DOT_PRODUCT, in, values).get().scorer(); + nativeDotScorer.setScoringOrdinal(0); + nativeSqrScorer = factory.getFloat32VectorScorerSupplier(EUCLIDEAN, in, values).get().scorer(); + nativeSqrScorer.setScoringOrdinal(0); + + // setup for getFloat32VectorScorer / query vector scoring + float[] queryVec = new float[dims]; + for (int i = 0; i < dims; i++) { + queryVec[i] = ThreadLocalRandom.current().nextFloat(); + } + luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec); + nativeDotScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); + luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec); + nativeSqrScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get(); + } + + @TearDown + public void teardown() throws IOException { + IOUtils.close(dir, in); + } + + // we score against two different ords to avoid the lastOrd cache in vector values + @Benchmark + public float dotProductLucene() throws IOException { + return luceneDotScorer.score(1) + luceneDotScorer.score(2); + } + + @Benchmark + public float dotProductNative() throws IOException { + return nativeDotScorer.score(1) + nativeDotScorer.score(2); + } + + @Benchmark + public float dotProductScalar() { + return dotProductScalarImpl(vec1, vec2) + dotProductScalarImpl(vec1, vec3); + } + + @Benchmark + public float dotProductLuceneQuery() throws IOException { + return luceneDotScorerQuery.score(1) + luceneDotScorerQuery.score(2); + } + + @Benchmark + public float dotProductNativeQuery() throws IOException { + return nativeDotScorerQuery.score(1) + nativeDotScorerQuery.score(2); + } + + // -- square distance + + @Benchmark + public float squareDistanceLucene() throws IOException { + return luceneSqrScorer.score(1) + luceneSqrScorer.score(2); + } + + @Benchmark + public float squareDistanceNative() throws IOException { + return nativeSqrScorer.score(1) + nativeSqrScorer.score(2); + } + + @Benchmark + public float squareDistanceScalar() { + return squareDistanceScalarImpl(vec1, vec2) + squareDistanceScalarImpl(vec1, vec3); + } + + @Benchmark + public float squareDistanceLuceneQuery() throws IOException { + return luceneSqrScorerQuery.score(1) + luceneSqrScorerQuery.score(2); + } + + @Benchmark + public float squareDistanceNativeQuery() throws IOException { + return nativeSqrScorerQuery.score(1) + nativeSqrScorerQuery.score(2); + } + + static float dotProductScalarImpl(float[] vec1, float[] vec2) { + float dot = 0; + for (int i = 0; i < vec1.length; i++) { + dot += vec1[i] * vec2[i]; + } + return Math.max((1 + dot) / 2, 0); + } + + static float squareDistanceScalarImpl(float[] vec1, float[] vec2) { + float dst = 0; + for (int i = 0; i < vec1.length; i++) { + float diff = vec1[i] - vec2[i]; + dst += diff * diff; + } + return 1 / (1f + dst); + } + + FloatVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + var slice = in.slice("values", 0, in.length()); + var byteSize = dims * Float.BYTES; + return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(dims, size, slice, byteSize, DefaultFlatVectorScorer.INSTANCE, sim); + } + + RandomVectorScorerSupplier luceneScoreSupplier(FloatVectorValues values, VectorSimilarityFunction sim) throws IOException { + return DefaultFlatVectorScorer.INSTANCE.getRandomVectorScorerSupplier(sim, values); + } + + RandomVectorScorer luceneScorer(FloatVectorValues values, VectorSimilarityFunction sim, float[] queryVec) throws IOException { + return DefaultFlatVectorScorer.INSTANCE.getRandomVectorScorer(sim, values, queryVec); + } + + static void writeFloat32Vectors(IndexOutput out, float[]... vectors) throws IOException { + var buffer = ByteBuffer.allocate(vectors[0].length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (var v : vectors) { + buffer.asFloatBuffer().put(v); + out.writeBytes(buffer.array(), buffer.array().length); + } + } + + static float[] randomFloatArray(int length) { + var random = ThreadLocalRandom.current(); + float[] fa = new float[length]; + for (int i = 0; i < length; i++) { + fa[i] = random.nextFloat(); + } + return fa; + } +} diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java similarity index 86% rename from benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java rename to benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java index f56bb8995b34e..9fc0b7dc199ac 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java @@ -55,24 +55,23 @@ /** * Benchmark that compares various scalar quantized vector similarity function * implementations;: scalar, lucene's panama-ized, and Elasticsearch's native. - * Run with ./gradlew -p benchmarks run --args 'VectorScorerBenchmark' + * Run with ./gradlew -p benchmarks run --args 'Int7uScorerBenchmark' */ -public class VectorScorerBenchmark { +public class Int7uScorerBenchmark { static { LogConfigurator.configureESLogging(); // native access requires logging to be initialized } @Param({ "96", "768", "1024" }) - int dims; - int size = 2; // there are only two vectors to compare + public int dims; + final int size = 2; // there are only two vectors to compare Directory dir; IndexInput in; VectorScorerFactory factory; - byte[] vec1; - byte[] vec2; + byte[] vec1, vec2; float vec1Offset; float vec2Offset; float scoreCorrectionConstant; @@ -139,39 +138,6 @@ public void setup() throws IOException { nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec); nativeSqrScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get(); - - // sanity - var f1 = dotProductLucene(); - var f2 = dotProductNative(); - var f3 = dotProductScalar(); - if (f1 != f2) { - throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]"); - } - if (f1 != f3) { - throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]"); - } - // square distance - f1 = squareDistanceLucene(); - f2 = squareDistanceNative(); - f3 = squareDistanceScalar(); - if (f1 != f2) { - throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]"); - } - if (f1 != f3) { - throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]"); - } - - var q1 = dotProductLuceneQuery(); - var q2 = dotProductNativeQuery(); - if (q1 != q2) { - throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]"); - } - - var sqr1 = squareDistanceLuceneQuery(); - var sqr2 = squareDistanceNativeQuery(); - if (sqr1 != sqr2) { - throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]"); - } } @TearDown diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java new file mode 100644 index 0000000000000..c8517c6a6c7fc --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32Benchmark.java @@ -0,0 +1,220 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.logging.NodeNamePatternConverter; +import org.elasticsearch.nativeaccess.NativeAccess; +import org.elasticsearch.nativeaccess.VectorSimilarityFunctions; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@State(Scope.Benchmark) +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 5, time = 1) +public class JDKVectorFloat32Benchmark { + + static { + NodeNamePatternConverter.setGlobalNodeName("foo"); + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + float[] floatsA; + float[] floatsB; + float[] scratch; + MemorySegment heapSegA, heapSegB; + MemorySegment nativeSegA, nativeSegB; + + Arena arena; + + @Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" }) + public int size; + + @Setup(Level.Iteration) + public void init() { + ThreadLocalRandom random = ThreadLocalRandom.current(); + + floatsA = new float[size]; + floatsB = new float[size]; + scratch = new float[size]; + for (int i = 0; i < size; ++i) { + floatsA[i] = random.nextFloat(); + floatsB[i] = random.nextFloat(); + } + heapSegA = MemorySegment.ofArray(floatsA); + heapSegB = MemorySegment.ofArray(floatsB); + + arena = Arena.ofConfined(); + nativeSegA = arena.allocate((long) floatsA.length * Float.BYTES); + MemorySegment.copy(MemorySegment.ofArray(floatsA), LAYOUT_LE_FLOAT, 0L, nativeSegA, LAYOUT_LE_FLOAT, 0L, floatsA.length); + nativeSegB = arena.allocate((long) floatsB.length * Float.BYTES); + MemorySegment.copy(MemorySegment.ofArray(floatsB), LAYOUT_LE_FLOAT, 0L, nativeSegB, LAYOUT_LE_FLOAT, 0L, floatsB.length); + } + + @TearDown + public void teardown() { + arena.close(); + } + + // -- cosine + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float cosineLucene() { + return VectorUtil.cosine(floatsA, floatsB); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float cosineLuceneWithCopy() { + // add a copy to better reflect what Lucene has to do to get the target vector on-heap + MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); + return VectorUtil.cosine(floatsA, scratch); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float cosineNativeWithNativeSeg() { + return cosineFloat32(nativeSegA, nativeSegB, size); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float cosineNativeWithHeapSeg() { + return cosineFloat32(heapSegA, heapSegB, size); + } + + // -- dot product + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductLucene() { + return VectorUtil.dotProduct(floatsA, floatsB); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductLuceneWithCopy() { + // add a copy to better reflect what Lucene has to do to get the target vector on-heap + MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); + return VectorUtil.dotProduct(floatsA, scratch); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductNativeWithNativeSeg() { + return dotProductFloat32(nativeSegA, nativeSegB, size); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float dotProductNativeWithHeapSeg() { + return dotProductFloat32(heapSegA, heapSegB, size); + } + + // -- square distance + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float squareDistanceLucene() { + return VectorUtil.squareDistance(floatsA, floatsB); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float squareDistanceLuceneWithCopy() { + // add a copy to better reflect what Lucene has to do to get the target vector on-heap + MemorySegment.copy(nativeSegB, LAYOUT_LE_FLOAT, 0L, scratch, 0, scratch.length); + return VectorUtil.squareDistance(floatsA, scratch); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float squareDistanceNativeWithNativeSeg() { + return squareDistanceFloat32(nativeSegA, nativeSegB, size); + } + + @Benchmark + @Fork(value = 3, jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public float squareDistanceNativeWithHeapSeg() { + return squareDistanceFloat32(heapSegA, heapSegB, size); + } + + static final VectorSimilarityFunctions vectorSimilarityFunctions = vectorSimilarityFunctions(); + + static VectorSimilarityFunctions vectorSimilarityFunctions() { + return NativeAccess.instance().getVectorSimilarityFunctions().get(); + } + + float cosineFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) vectorSimilarityFunctions.cosineHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float dotProductFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) vectorSimilarityFunctions.dotProductHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) vectorSimilarityFunctions.squareDistanceHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } +} diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java index 41c2b3192cc92..e09f37ef24086 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java @@ -52,7 +52,7 @@ public class JDKVectorInt7uBenchmark { Arena arena; - @Param({ "1", "128", "207", "256", "300", "512", "702", "1024" }) + @Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" }) public int size; @Setup(Level.Iteration) diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java new file mode 100644 index 0000000000000..ca262b36e41d3 --- /dev/null +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.Constants; +import org.elasticsearch.test.ESTestCase; +import org.junit.BeforeClass; +import org.openjdk.jmh.annotations.Param; + +import java.util.Arrays; + +public class Float32ScorerBenchmarkTests extends ESTestCase { + + final double delta = 1e-3; + final int dims; + + public Float32ScorerBenchmarkTests(int dims) { + this.dims = dims; + } + + @BeforeClass + public static void skipWindows() { + assumeFalse("doesn't work on windows yet", Constants.WINDOWS); + } + + public void testDotProduct() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Float32ScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.dotProductScalar(); + assertEquals(expected, bench.dotProductLucene(), delta); + assertEquals(expected, bench.dotProductNative(), delta); + + expected = bench.dotProductLuceneQuery(); + assertEquals(expected, bench.dotProductNativeQuery(), delta); + } finally { + bench.teardown(); + } + } + } + + public void testSquareDistance() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Float32ScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.squareDistanceScalar(); + assertEquals(expected, bench.squareDistanceLucene(), delta); + assertEquals(expected, bench.squareDistanceNative(), delta); + + expected = bench.squareDistanceLuceneQuery(); + assertEquals(expected, bench.squareDistanceNativeQuery(), delta); + } finally { + bench.teardown(); + } + } + } + + @ParametersFactory + public static Iterable parametersFactory() { + try { + var params = Float32ScorerBenchmark.class.getField("dims").getAnnotationsByType(Param.class)[0].value(); + return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator(); + } catch (NoSuchFieldException e) { + throw new AssertionError(e); + } + } +} diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmarkTests.java new file mode 100644 index 0000000000000..b8cf689783d1d --- /dev/null +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmarkTests.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.Constants; +import org.elasticsearch.test.ESTestCase; +import org.junit.BeforeClass; +import org.openjdk.jmh.annotations.Param; + +import java.util.Arrays; + +public class Int7uScorerBenchmarkTests extends ESTestCase { + + final double delta = 1e-3; + final int dims; + + public Int7uScorerBenchmarkTests(int dims) { + this.dims = dims; + } + + @BeforeClass + public static void skipWindows() { + assumeFalse("doesn't work on windows yet", Constants.WINDOWS); + } + + public void testDotProduct() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Int7uScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.dotProductScalar(); + assertEquals(expected, bench.dotProductLucene(), delta); + assertEquals(expected, bench.dotProductNative(), delta); + + expected = bench.dotProductLuceneQuery(); + assertEquals(expected, bench.dotProductNativeQuery(), delta); + } finally { + bench.teardown(); + } + } + } + + public void testSquareDistance() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Int7uScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.squareDistanceScalar(); + assertEquals(expected, bench.squareDistanceLucene(), delta); + assertEquals(expected, bench.squareDistanceNative(), delta); + + expected = bench.squareDistanceLuceneQuery(); + assertEquals(expected, bench.squareDistanceNativeQuery(), delta); + } finally { + bench.teardown(); + } + } + } + + @ParametersFactory + public static Iterable parametersFactory() { + try { + var params = Int7uScorerBenchmark.class.getField("dims").getAnnotationsByType(Param.class)[0].value(); + return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator(); + } catch (NoSuchFieldException e) { + throw new AssertionError(e); + } + } +} diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java new file mode 100644 index 0000000000000..578f9701c9ccc --- /dev/null +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/JDKVectorFloat32BenchmarkTests.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.Constants; +import org.elasticsearch.test.ESTestCase; +import org.junit.BeforeClass; +import org.openjdk.jmh.annotations.Param; + +import java.util.Arrays; + +public class JDKVectorFloat32BenchmarkTests extends ESTestCase { + + final double delta; + final int size; + + public JDKVectorFloat32BenchmarkTests(int size) { + this.size = size; + delta = 1e-3 * size; + } + + @BeforeClass + public static void skipWindows() { + assumeFalse("doesn't work on windows yet", Constants.WINDOWS); + } + + static boolean supportsHeapSegments() { + return Runtime.version().feature() >= 22; + } + + public void testCosine() { + for (int i = 0; i < 100; i++) { + var bench = new JDKVectorFloat32Benchmark(); + bench.size = size; + bench.init(); + try { + float expected = cosineFloat32Scalar(bench.floatsA, bench.floatsB); + assertEquals(expected, bench.cosineLucene(), delta); + assertEquals(expected, bench.cosineLuceneWithCopy(), delta); + assertEquals(expected, bench.cosineNativeWithNativeSeg(), delta); + if (supportsHeapSegments()) { + assertEquals(expected, bench.cosineNativeWithHeapSeg(), delta); + } + } finally { + bench.teardown(); + } + } + } + + public void testDotProduct() { + for (int i = 0; i < 100; i++) { + var bench = new JDKVectorFloat32Benchmark(); + bench.size = size; + bench.init(); + try { + float expected = dotProductFloat32Scalar(bench.floatsA, bench.floatsB); + assertEquals(expected, bench.dotProductLucene(), delta); + assertEquals(expected, bench.dotProductLuceneWithCopy(), delta); + assertEquals(expected, bench.dotProductNativeWithNativeSeg(), delta); + if (supportsHeapSegments()) { + assertEquals(expected, bench.dotProductNativeWithHeapSeg(), delta); + } + } finally { + bench.teardown(); + } + } + } + + public void testSquareDistance() { + for (int i = 0; i < 100; i++) { + var bench = new JDKVectorFloat32Benchmark(); + bench.size = size; + bench.init(); + try { + float expected = squareDistanceFloat32Scalar(bench.floatsA, bench.floatsB); + assertEquals(expected, bench.squareDistanceLucene(), delta); + assertEquals(expected, bench.squareDistanceLuceneWithCopy(), delta); + assertEquals(expected, bench.squareDistanceNativeWithNativeSeg(), delta); + if (supportsHeapSegments()) { + assertEquals(expected, bench.squareDistanceNativeWithHeapSeg(), delta); + } + } finally { + bench.teardown(); + } + } + } + + @ParametersFactory + public static Iterable parametersFactory() { + try { + var params = JDKVectorFloat32Benchmark.class.getField("size").getAnnotationsByType(Param.class)[0].value(); + return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator(); + } catch (NoSuchFieldException e) { + throw new AssertionError(e); + } + } + + /** Computes the cosine of the given vectors a and b. */ + static float cosineFloat32Scalar(float[] a, float[] b) { + float dot = 0, normA = 0, normB = 0; + for (int i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + double normAA = Math.sqrt(normA); + double normBB = Math.sqrt(normB); + if (normAA == 0.0f || normBB == 0.0f) return 0.0f; + return (float) (dot / (normAA * normBB)); + } + + /** Computes the dot product of the given vectors a and b. */ + static float dotProductFloat32Scalar(float[] a, float[] b) { + float res = 0; + for (int i = 0; i < a.length; i++) { + res += a[i] * b[i]; + } + return res; + } + + /** Computes the dot product of the given vectors a and b. */ + static float squareDistanceFloat32Scalar(float[] a, float[] b) { + float squareSum = 0; + for (int i = 0; i < a.length; i++) { + float diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } +} diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 58562ddcd6882..e00090e6df4c4 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -52,7 +52,7 @@ dependencies { libs "org.elasticsearch:zstd:${zstdVersion}:linux-aarch64" libs "org.elasticsearch:zstd:${zstdVersion}:linux-x86-64" libs "org.elasticsearch:zstd:${zstdVersion}:windows-x86-64" - libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib +// libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib } def extractLibs = tasks.register('extractLibs', Copy) { diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java index 29a298b714fdd..4d3f6bc5b2c79 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java @@ -40,4 +40,34 @@ public interface VectorSimilarityFunctions { * vector data bytes. The third argument is the length of the vector data. */ MethodHandle squareDistanceHandle7u(); + + /** + * Produces a method handle returning the cosine of float32 vectors. + * + *

The type of the method handle will have {@code float} as return type, The type of + * its first and second arguments will be {@code MemorySegment}, whose contents is the + * vector data floats. The third argument is the length of the vector data - number of + * 4-byte float32 elements. + */ + MethodHandle cosineHandleFloat32(); + + /** + * Produces a method handle returning the dot product of float32 vectors. + * + *

The type of the method handle will have {@code float} as return type, The type of + * its first and second arguments will be {@code MemorySegment}, whose contents is the + * vector data floats. The third argument is the length of the vector data - number of + * 4-byte float32 elements. + */ + MethodHandle dotProductHandleFloat32(); + + /** + * Produces a method handle returning the square distance of float32 vectors. + * + *

The type of the method handle will have {@code float} as return type, The type of + * its first and second arguments will be {@code MemorySegment}, whose contents is the + * vector data floats. The third argument is the length of the vector data - number of + * 4-byte float32 elements. + */ + MethodHandle squareDistanceHandleFloat32(); } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java index 2b56e65f39aae..2c429283d64ef 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java @@ -23,6 +23,7 @@ import java.util.Objects; import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_FLOAT; import static java.lang.foreign.ValueLayout.JAVA_INT; import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; @@ -32,8 +33,11 @@ public final class JdkVectorLibrary implements VectorLibrary { static final MethodHandle dot7u$mh; static final MethodHandle sqr7u$mh; + static final MethodHandle cosf32$mh; + static final MethodHandle dotf32$mh; + static final MethodHandle sqrf32$mh; - static final VectorSimilarityFunctions INSTANCE; + public static final JdkVectorSimilarityFunctions INSTANCE; static { LoaderHelper.loadLibrary("vec"); @@ -54,6 +58,21 @@ public final class JdkVectorLibrary implements VectorLibrary { FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), LinkerHelperUtil.critical() ); + cosf32$mh = downcallHandle( + "cosf32_2", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + dotf32$mh = downcallHandle( + "dotf32_2", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + sqrf32$mh = downcallHandle( + "sqrf32_2", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); } else { dot7u$mh = downcallHandle( "dot7u", @@ -65,6 +84,21 @@ public final class JdkVectorLibrary implements VectorLibrary { FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT), LinkerHelperUtil.critical() ); + cosf32$mh = downcallHandle( + "cosf32", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + dotf32$mh = downcallHandle( + "dotf32", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); + sqrf32$mh = downcallHandle( + "sqrf32", + FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT), + LinkerHelperUtil.critical() + ); } INSTANCE = new JdkVectorSimilarityFunctions(); } else { @@ -75,6 +109,9 @@ public final class JdkVectorLibrary implements VectorLibrary { } dot7u$mh = null; sqr7u$mh = null; + cosf32$mh = null; + dotf32$mh = null; + sqrf32$mh = null; INSTANCE = null; } } catch (Throwable t) { @@ -120,7 +157,46 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { return sqr7u(a, b, length); } - static void checkByteSize(MemorySegment a, MemorySegment b) { + /** + * Computes the cosine of given float32 vectors. + * + * @param a address of the first vector + * @param b address of the second vector + * @param elementCount the vector dimensions, number of float32 elements in the segment + */ + static float cosineF32(MemorySegment a, MemorySegment b, int elementCount) { + checkByteSize(a, b); + Objects.checkFromIndexSize(0, elementCount, (int) a.byteSize() / Float.BYTES); + return cosf32(a, b, elementCount); + } + + /** + * Computes the dot product of given float32 vectors. + * + * @param a address of the first vector + * @param b address of the second vector + * @param elementCount the vector dimensions, number of float32 elements in the segment + */ + static float dotProductF32(MemorySegment a, MemorySegment b, int elementCount) { + checkByteSize(a, b); + Objects.checkFromIndexSize(0, elementCount, (int) a.byteSize() / Float.BYTES); + return dotf32(a, b, elementCount); + } + + /** + * Computes the square distance of given float32 vectors. + * + * @param a address of the first vector + * @param b address of the second vector + * @param elementCount the vector dimensions, number of float32 elements in the segment + */ + static float squareDistanceF32(MemorySegment a, MemorySegment b, int elementCount) { + checkByteSize(a, b); + Objects.checkFromIndexSize(0, elementCount, (int) a.byteSize() / Float.BYTES); + return sqrf32(a, b, elementCount); + } + + private static void checkByteSize(MemorySegment a, MemorySegment b) { if (a.byteSize() != b.byteSize()) { throw new IllegalArgumentException("dimensions differ: " + a.byteSize() + "!=" + b.byteSize()); } @@ -142,8 +218,35 @@ private static int sqr7u(MemorySegment a, MemorySegment b, int length) { } } + private static float cosf32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) JdkVectorLibrary.cosf32$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private static float dotf32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) JdkVectorLibrary.dotf32$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + private static float sqrf32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) JdkVectorLibrary.sqrf32$mh.invokeExact(a, b, length); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + static final MethodHandle DOT_HANDLE_7U; static final MethodHandle SQR_HANDLE_7U; + static final MethodHandle COS_HANDLE_FLOAT32; + static final MethodHandle DOT_HANDLE_FLOAT32; + static final MethodHandle SQR_HANDLE_FLOAT32; static { try { @@ -151,6 +254,11 @@ private static int sqr7u(MemorySegment a, MemorySegment b, int length) { var mt = MethodType.methodType(int.class, MemorySegment.class, MemorySegment.class, int.class); DOT_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7u", mt); SQR_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance7u", mt); + + mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class); + COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt); + DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt); + SQR_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistanceF32", mt); } catch (NoSuchMethodException | IllegalAccessException e) { throw new RuntimeException(e); } @@ -165,5 +273,20 @@ public MethodHandle dotProductHandle7u() { public MethodHandle squareDistanceHandle7u() { return SQR_HANDLE_7U; } + + @Override + public MethodHandle cosineHandleFloat32() { + return COS_HANDLE_FLOAT32; + } + + @Override + public MethodHandle dotProductHandleFloat32() { + return DOT_HANDLE_FLOAT32; + } + + @Override + public MethodHandle squareDistanceHandleFloat32() { + return SQR_HANDLE_FLOAT32; + } } } diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java index 53fd6c7f1fa6b..3d8433bf36487 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java @@ -9,22 +9,54 @@ package org.elasticsearch.nativeaccess; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.logging.NodeNamePatternConverter; import org.elasticsearch.test.ESTestCase; +import java.lang.foreign.Arena; +import java.util.Arrays; import java.util.Optional; +import java.util.stream.IntStream; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; import static org.hamcrest.Matchers.not; -public class VectorSimilarityFunctionsTests extends ESTestCase { +public abstract class VectorSimilarityFunctionsTests extends ESTestCase { - final Optional vectorSimilarityFunctions; + static { + NodeNamePatternConverter.setGlobalNodeName("foo"); + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + public static final Class IAE = IllegalArgumentException.class; + public static final Class IOOBE = IndexOutOfBoundsException.class; + + protected static Arena arena; + + protected final int size; + protected final Optional vectorSimilarityFunctions; + + protected static Iterable parametersFactory() { + var dims1 = Arrays.stream(new int[] { 1, 2, 4, 6, 8, 12, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 768 }); + var dims2 = Arrays.stream(new int[] { 1000, 1023, 1024, 1025, 2047, 2048, 2049, 4095, 4096, 4097 }); + return () -> IntStream.concat(dims1, dims2).boxed().map(i -> new Object[] { i }).iterator(); + } - public VectorSimilarityFunctionsTests() { + protected VectorSimilarityFunctionsTests(int size) { logger.info(platformMsg()); + this.size = size; vectorSimilarityFunctions = NativeAccess.instance().getVectorSimilarityFunctions(); } + public static void setup() { + arena = Arena.ofConfined(); + } + + public static void cleanup() { + arena.close(); + } + public void testSupported() { supported(); } @@ -59,4 +91,9 @@ public static String platformMsg() { var osName = System.getProperty("os.name"); return "JDK=" + jdkVersion + ", os=" + osName + ", arch=" + arch; } + + // Support for passing on-heap arrays/segments to native + protected static boolean supportsHeapSegments() { + return Runtime.version().feature() >= 22; + } } diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java new file mode 100644 index 0000000000000..37473a85a6b96 --- /dev/null +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java @@ -0,0 +1,225 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.nativeaccess.jdk; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.nativeaccess.VectorSimilarityFunctionsTests; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.ByteOrder; +import java.util.function.IntFunction; + +import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED; +import static org.hamcrest.Matchers.containsString; + +public class JDKVectorLibraryFloat32Tests extends VectorSimilarityFunctionsTests { + + static final ValueLayout.OfFloat LAYOUT_LE_FLOAT = ValueLayout.JAVA_FLOAT_UNALIGNED.withOrder(ByteOrder.LITTLE_ENDIAN); + + final double delta; + + public JDKVectorLibraryFloat32Tests(int size) { + super(size); + this.delta = 1e-5 * size; // scale the delta with the size + } + + @BeforeClass + public static void beforeClass() { + VectorSimilarityFunctionsTests.setup(); + } + + @AfterClass + public static void afterClass() { + VectorSimilarityFunctionsTests.cleanup(); + } + + @ParametersFactory + public static Iterable parametersFactory() { + return VectorSimilarityFunctionsTests.parametersFactory(); + } + + public void testAllZeroValues() { + testFloat32Impl(float[]::new); + } + + public void testRandomFloats() { + testFloat32Impl(JDKVectorLibraryFloat32Tests::randomFloatArray); + } + + public void testFloat32Impl(IntFunction vectorGeneratorFunc) { + assumeTrue(notSupportedMsg(), supported()); + final int dims = size; + final int numVecs = randomIntBetween(2, 101); + var values = new float[numVecs][dims]; + var segment = arena.allocate((long) dims * numVecs * Float.BYTES); + for (int i = 0; i < numVecs; i++) { + values[i] = vectorGeneratorFunc.apply(dims); + long dstOffset = (long) i * dims * Float.BYTES; + MemorySegment.copy(MemorySegment.ofArray(values[i]), JAVA_FLOAT_UNALIGNED, 0L, segment, LAYOUT_LE_FLOAT, dstOffset, dims); + } + + final int loopTimes = 1000; + for (int i = 0; i < loopTimes; i++) { + int first = randomInt(numVecs - 1); + int second = randomInt(numVecs - 1); + var nativeSeg1 = segment.asSlice((long) first * dims * Float.BYTES, (long) dims * Float.BYTES); + var nativeSeg2 = segment.asSlice((long) second * dims * Float.BYTES, (long) dims * Float.BYTES); + + // cosine + float expected = cosineFloat32Scalar(values[first], values[second]); + assertEquals(expected, cosineFloat32(nativeSeg1, nativeSeg2, dims), delta); + if (supportsHeapSegments()) { + var heapSeg1 = MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, cosineFloat32(heapSeg1, heapSeg2, dims), delta); + assertEquals(expected, cosineFloat32(nativeSeg1, heapSeg2, dims), delta); + assertEquals(expected, cosineFloat32(heapSeg1, nativeSeg2, dims), delta); + } + + // dot product + expected = dotProductFloat32Scalar(values[first], values[second]); + assertEquals(expected, dotProductFloat32(nativeSeg1, nativeSeg2, dims), delta); + if (supportsHeapSegments()) { + var heapSeg1 = MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, dotProductFloat32(heapSeg1, heapSeg2, dims), delta); + assertEquals(expected, dotProductFloat32(nativeSeg1, heapSeg2, dims), delta); + assertEquals(expected, dotProductFloat32(heapSeg1, nativeSeg2, dims), delta); + } + + // square distance + expected = squareDistanceFloat32Scalar(values[first], values[second]); + assertEquals(expected, squareDistanceFloat32(nativeSeg1, nativeSeg2, dims), delta); + if (supportsHeapSegments()) { + var heapSeg1 = MemorySegment.ofArray(values[first]); + var heapSeg2 = MemorySegment.ofArray(values[second]); + assertEquals(expected, squareDistanceFloat32(heapSeg1, heapSeg2, dims), delta); + assertEquals(expected, squareDistanceFloat32(nativeSeg1, heapSeg2, dims), delta); + assertEquals(expected, squareDistanceFloat32(heapSeg1, nativeSeg2, dims), delta); + } + + } + } + + public void testIllegalDims() { + assumeTrue(notSupportedMsg(), supported()); + var segment = arena.allocate((long) size * 3 * Float.BYTES); + + var e1 = expectThrows(IAE, () -> cosineFloat32(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size)); + assertThat(e1.getMessage(), containsString("dimensions differ")); + e1 = expectThrows(IAE, () -> dotProductFloat32(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size)); + assertThat(e1.getMessage(), containsString("dimensions differ")); + e1 = expectThrows(IAE, () -> squareDistanceFloat32(segment.asSlice(0L, size), segment.asSlice(size, size + 1), size)); + assertThat(e1.getMessage(), containsString("dimensions differ")); + + var e2 = expectThrows(IOOBE, () -> cosineFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> dotProductFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> squareDistanceFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), size + 1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + + e2 = expectThrows(IOOBE, () -> cosineFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), -1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> dotProductFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), -1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + e2 = expectThrows(IOOBE, () -> squareDistanceFloat32(segment.asSlice(0L, size), segment.asSlice(size, size), -1)); + assertThat(e2.getMessage(), containsString("out of bounds for length")); + } + + float cosineFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) getVectorDistance().cosineHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float dotProductFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) getVectorDistance().dotProductHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + float squareDistanceFloat32(MemorySegment a, MemorySegment b, int length) { + try { + return (float) getVectorDistance().squareDistanceHandleFloat32().invokeExact(a, b, length); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + static float[] randomFloatArray(int length) { + float[] fa = new float[length]; + for (int i = 0; i < length; i++) { + fa[i] = randomFloat(); + } + return fa; + } + + /** Computes the cosine of the given vectors a and b. */ + static float cosineFloat32Scalar(float[] a, float[] b) { + float dot = 0, normA = 0, normB = 0; + for (int i = 0; i < a.length; i++) { + dot += a[i] * b[i]; + normA += a[i] * a[i]; + normB += b[i] * b[i]; + } + double normAA = Math.sqrt(normA); + double normBB = Math.sqrt(normB); + if (normAA == 0.0f || normBB == 0.0f) { + return 0.0f; + } + return (float) (dot / (normAA * normBB)); + } + + /** Computes the dot product of the given vectors a and b. */ + static float dotProductFloat32Scalar(float[] a, float[] b) { + float res = 0; + for (int i = 0; i < a.length; i++) { + res += a[i] * b[i]; + } + return res; + } + + /** Computes the dot product of the given vectors a and b. */ + static float squareDistanceFloat32Scalar(float[] a, float[] b) { + float squareSum = 0; + for (int i = 0; i < a.length; i++) { + float diff = a[i] - b[i]; + squareSum += diff * diff; + } + return squareSum; + } +} diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java similarity index 85% rename from libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java rename to libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java index 04f80ba72891f..effad86d74a3e 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java @@ -15,47 +15,33 @@ import org.junit.AfterClass; import org.junit.BeforeClass; -import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; -import java.util.stream.IntStream; import static org.hamcrest.Matchers.containsString; -public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests { +public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests { // bounds of the range of values that can be seen by int7 scalar quantized vectors static final byte MIN_INT7_VALUE = 0; static final byte MAX_INT7_VALUE = 127; - static final Class IAE = IllegalArgumentException.class; - static final Class IOOBE = IndexOutOfBoundsException.class; - - static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 }; - - final int size; - - static Arena arena; - - final double delta; - - public JDKVectorLibraryTests(int size) { - this.size = size; - this.delta = 1e-5 * size; // scale the delta with the size + public JDKVectorLibraryInt7uTests(int size) { + super(size); } @BeforeClass - public static void setup() { - arena = Arena.ofConfined(); + public static void beforeClass() { + VectorSimilarityFunctionsTests.setup(); } @AfterClass - public static void cleanup() { - arena.close(); + public static void afterClass() { + VectorSimilarityFunctionsTests.cleanup(); } @ParametersFactory public static Iterable parametersFactory() { - return () -> IntStream.of(VECTOR_DIMS).boxed().map(i -> new Object[] { i }).iterator(); + return VectorSimilarityFunctionsTests.parametersFactory(); } public void testInt7BinaryVectors() { @@ -79,7 +65,7 @@ public void testInt7BinaryVectors() { // dot product int expected = dotProductScalar(values[first], values[second]); assertEquals(expected, dotProduct7u(nativeSeg1, nativeSeg2, dims)); - if (testWithHeapSegments()) { + if (supportsHeapSegments()) { var heapSeg1 = MemorySegment.ofArray(values[first]); var heapSeg2 = MemorySegment.ofArray(values[second]); assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims)); @@ -90,7 +76,7 @@ public void testInt7BinaryVectors() { // square distance expected = squareDistanceScalar(values[first], values[second]); assertEquals(expected, squareDistance7u(nativeSeg1, nativeSeg2, dims)); - if (testWithHeapSegments()) { + if (supportsHeapSegments()) { var heapSeg1 = MemorySegment.ofArray(values[first]); var heapSeg2 = MemorySegment.ofArray(values[second]); assertEquals(expected, squareDistance7u(heapSeg1, heapSeg2, dims)); @@ -100,10 +86,6 @@ public void testInt7BinaryVectors() { } } - static boolean testWithHeapSegments() { - return Runtime.version().feature() >= 22; - } - public void testIllegalDims() { assumeTrue(notSupportedMsg(), supported()); var segment = arena.allocate((long) size * 3); diff --git a/libs/simdvec/native/src/vec/c/aarch64/vec.c b/libs/simdvec/native/src/vec/c/aarch64/vec.c index 59c0cdb2ff8ff..f3eb7f51ee5d1 100644 --- a/libs/simdvec/native/src/vec/c/aarch64/vec.c +++ b/libs/simdvec/native/src/vec/c/aarch64/vec.c @@ -9,6 +9,7 @@ #include #include +#include #include "vec.h" #ifndef DOT7U_STRIDE_BYTES_LEN @@ -132,3 +133,169 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { } return res; } + +// --- single precision floats + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float dotf32(const float *a, const float *b, size_t elementCount) { + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + float32x4_t sum2 = vdupq_n_f32(0.0f); + float32x4_t sum3 = vdupq_n_f32(0.0f); + float32x4_t sum4 = vdupq_n_f32(0.0f); + float32x4_t sum5 = vdupq_n_f32(0.0f); + float32x4_t sum6 = vdupq_n_f32(0.0f); + float32x4_t sum7 = vdupq_n_f32(0.0f); + + size_t i = 0; + // Each float32x4_t holds 4 floats, so unroll 8x = 32 floats per loop + size_t unrolled_limit = elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + sum0 = vfmaq_f32(sum0, vld1q_f32(a + i), vld1q_f32(b + i)); + sum1 = vfmaq_f32(sum1, vld1q_f32(a + i + 4), vld1q_f32(b + i + 4)); + sum2 = vfmaq_f32(sum2, vld1q_f32(a + i + 8), vld1q_f32(b + i + 8)); + sum3 = vfmaq_f32(sum3, vld1q_f32(a + i + 12), vld1q_f32(b + i + 12)); + sum4 = vfmaq_f32(sum4, vld1q_f32(a + i + 16), vld1q_f32(b + i + 16)); + sum5 = vfmaq_f32(sum5, vld1q_f32(a + i + 20), vld1q_f32(b + i + 20)); + sum6 = vfmaq_f32(sum6, vld1q_f32(a + i + 24), vld1q_f32(b + i + 24)); + sum7 = vfmaq_f32(sum7, vld1q_f32(a + i + 28), vld1q_f32(b + i + 28)); + } + + float32x4_t total = vaddq_f32( + vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)), + vaddq_f32(vaddq_f32(sum4, sum5), vaddq_f32(sum6, sum7)) + ); + float result = vaddvq_f32(total); + + // Handle remaining elements + for (; i < elementCount; ++i) { + result += a[i] * b[i]; + } + + return result; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float cosf32(const float *a, const float *b, size_t elementCount) { + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + float32x4_t sum2 = vdupq_n_f32(0.0f); + float32x4_t sum3 = vdupq_n_f32(0.0f); + + float32x4_t norm_a0 = vdupq_n_f32(0.0f); + float32x4_t norm_a1 = vdupq_n_f32(0.0f); + float32x4_t norm_a2 = vdupq_n_f32(0.0f); + float32x4_t norm_a3 = vdupq_n_f32(0.0f); + + float32x4_t norm_b0 = vdupq_n_f32(0.0f); + float32x4_t norm_b1 = vdupq_n_f32(0.0f); + float32x4_t norm_b2 = vdupq_n_f32(0.0f); + float32x4_t norm_b3 = vdupq_n_f32(0.0f); + + size_t i = 0; + // Each float32x4_t holds 4 floats, so unroll 4x = 16 floats per loop + size_t unrolled_limit = elementCount & ~15UL; + for (; i < unrolled_limit; i += 16) { + float32x4_t va0 = vld1q_f32(a + i); + float32x4_t vb0 = vld1q_f32(b + i); + float32x4_t va1 = vld1q_f32(a + i + 4); + float32x4_t vb1 = vld1q_f32(b + i + 4); + float32x4_t va2 = vld1q_f32(a + i + 8); + float32x4_t vb2 = vld1q_f32(b + i + 8); + float32x4_t va3 = vld1q_f32(a + i + 12); + float32x4_t vb3 = vld1q_f32(b + i + 12); + + // Dot products + sum0 = vfmaq_f32(sum0, va0, vb0); + sum1 = vfmaq_f32(sum1, va1, vb1); + sum2 = vfmaq_f32(sum2, va2, vb2); + sum3 = vfmaq_f32(sum3, va3, vb3); + + // Norms + norm_a0 = vfmaq_f32(norm_a0, va0, va0); + norm_a1 = vfmaq_f32(norm_a1, va1, va1); + norm_a2 = vfmaq_f32(norm_a2, va2, va2); + norm_a3 = vfmaq_f32(norm_a3, va3, va3); + + norm_b0 = vfmaq_f32(norm_b0, vb0, vb0); + norm_b1 = vfmaq_f32(norm_b1, vb1, vb1); + norm_b2 = vfmaq_f32(norm_b2, vb2, vb2); + norm_b3 = vfmaq_f32(norm_b3, vb3, vb3); + } + + // Combine accumulators + float32x4_t sums = vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)); + float32x4_t norms_a = vaddq_f32(vaddq_f32(norm_a0, norm_a1), vaddq_f32(norm_a2, norm_a3)); + float32x4_t norms_b = vaddq_f32(vaddq_f32(norm_b0, norm_b1), vaddq_f32(norm_b2, norm_b3)); + + float dot = vaddvq_f32(sums); + float norm_a = vaddvq_f32(norms_a); + float norm_b = vaddvq_f32(norms_b); + + // Handle remaining tail elements + for (; i < elementCount; ++i) { + float va = a[i]; + float vb = b[i]; + dot += va * vb; + norm_a += va * va; + norm_b += vb * vb; + } + + float denom = sqrtf(norm_a) * sqrtf(norm_b); + if (denom == 0.0f) { + return 0.0f; + } + return dot / denom; +} + +EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { + float32x4_t sum0 = vdupq_n_f32(0.0f); + float32x4_t sum1 = vdupq_n_f32(0.0f); + float32x4_t sum2 = vdupq_n_f32(0.0f); + float32x4_t sum3 = vdupq_n_f32(0.0f); + float32x4_t sum4 = vdupq_n_f32(0.0f); + float32x4_t sum5 = vdupq_n_f32(0.0f); + float32x4_t sum6 = vdupq_n_f32(0.0f); + float32x4_t sum7 = vdupq_n_f32(0.0f); + + size_t i = 0; + // Each float32x4_t holds 4 floats, so unroll 8x = 32 floats per loop + size_t unrolled_limit = elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + float32x4_t d0 = vsubq_f32(vld1q_f32(a + i), vld1q_f32(b + i)); + float32x4_t d1 = vsubq_f32(vld1q_f32(a + i + 4), vld1q_f32(b + i + 4)); + float32x4_t d2 = vsubq_f32(vld1q_f32(a + i + 8), vld1q_f32(b + i + 8)); + float32x4_t d3 = vsubq_f32(vld1q_f32(a + i + 12), vld1q_f32(b + i + 12)); + float32x4_t d4 = vsubq_f32(vld1q_f32(a + i + 16), vld1q_f32(b + i + 16)); + float32x4_t d5 = vsubq_f32(vld1q_f32(a + i + 20), vld1q_f32(b + i + 20)); + float32x4_t d6 = vsubq_f32(vld1q_f32(a + i + 24), vld1q_f32(b + i + 24)); + float32x4_t d7 = vsubq_f32(vld1q_f32(a + i + 28), vld1q_f32(b + i + 28)); + + sum0 = vmlaq_f32(sum0, d0, d0); + sum1 = vmlaq_f32(sum1, d1, d1); + sum2 = vmlaq_f32(sum2, d2, d2); + sum3 = vmlaq_f32(sum3, d3, d3); + sum4 = vmlaq_f32(sum4, d4, d4); + sum5 = vmlaq_f32(sum5, d5, d5); + sum6 = vmlaq_f32(sum6, d6, d6); + sum7 = vmlaq_f32(sum7, d7, d7); + } + + float32x4_t total = vaddq_f32( + vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)), + vaddq_f32(vaddq_f32(sum4, sum5), vaddq_f32(sum6, sum7)) + ); + float result = vaddvq_f32(total); + + // Handle remaining tail elements + for (; i < elementCount; ++i) { + float diff = a[i] - b[i]; + result += diff * diff; + } + + return result; +} diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index f63a7649b1390..24a648747541c 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ -9,6 +9,7 @@ #include #include +#include #include "vec.h" #include @@ -187,3 +188,161 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { } return res; } + +// --- single precision floats + +// Horizontal add of all 8 elements in a __m256 register +static inline float horizontal_sum_avx2(__m256 v) { + // First, add the low and high 128-bit lanes + __m128 low = _mm256_castps256_ps128(v); // lower 128 bits + __m128 high = _mm256_extractf128_ps(v, 1); // upper 128 bits + __m128 sum128 = _mm_add_ps(low, high); // sum 8 floats → 4 floats + + // Then do horizontal sum within 128-bit lane + __m128 shuf = _mm_movehdup_ps(sum128); // duplicate odd-index elements + __m128 sums = _mm_add_ps(sum128, shuf); // add pairs + + shuf = _mm_movehl_ps(shuf, sums); // move high pair to low + sums = _mm_add_ss(sums, shuf); // add final two elements + + return _mm_cvtss_f32(sums); +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float cosf32(const float *a, const float *b, size_t elementCount) { + __m256 dot0 = _mm256_setzero_ps(); + __m256 dot1 = _mm256_setzero_ps(); + __m256 dot2 = _mm256_setzero_ps(); + __m256 dot3 = _mm256_setzero_ps(); + + __m256 norm_a0 = _mm256_setzero_ps(); + __m256 norm_a1 = _mm256_setzero_ps(); + __m256 norm_a2 = _mm256_setzero_ps(); + __m256 norm_a3 = _mm256_setzero_ps(); + + __m256 norm_b0 = _mm256_setzero_ps(); + __m256 norm_b1 = _mm256_setzero_ps(); + __m256 norm_b2 = _mm256_setzero_ps(); + __m256 norm_b3 = _mm256_setzero_ps(); + + size_t i = 0; + // Each __m256 holds 8 floats, so unroll 4x = 32 floats per loop + size_t unrolled_limit = elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + __m256 a0 = _mm256_loadu_ps(a + i); + __m256 b0 = _mm256_loadu_ps(b + i); + __m256 a1 = _mm256_loadu_ps(a + i + 8); + __m256 b1 = _mm256_loadu_ps(b + i + 8); + __m256 a2 = _mm256_loadu_ps(a + i + 16); + __m256 b2 = _mm256_loadu_ps(b + i + 16); + __m256 a3 = _mm256_loadu_ps(a + i + 24); + __m256 b3 = _mm256_loadu_ps(b + i + 24); + + dot0 = _mm256_fmadd_ps(a0, b0, dot0); + dot1 = _mm256_fmadd_ps(a1, b1, dot1); + dot2 = _mm256_fmadd_ps(a2, b2, dot2); + dot3 = _mm256_fmadd_ps(a3, b3, dot3); + + norm_a0 = _mm256_fmadd_ps(a0, a0, norm_a0); + norm_a1 = _mm256_fmadd_ps(a1, a1, norm_a1); + norm_a2 = _mm256_fmadd_ps(a2, a2, norm_a2); + norm_a3 = _mm256_fmadd_ps(a3, a3, norm_a3); + + norm_b0 = _mm256_fmadd_ps(b0, b0, norm_b0); + norm_b1 = _mm256_fmadd_ps(b1, b1, norm_b1); + norm_b2 = _mm256_fmadd_ps(b2, b2, norm_b2); + norm_b3 = _mm256_fmadd_ps(b3, b3, norm_b3); + } + + // combine and reduce vector accumulators + __m256 dot_total = _mm256_add_ps(_mm256_add_ps(dot0, dot1), _mm256_add_ps(dot2, dot3)); + __m256 norm_a_total = _mm256_add_ps(_mm256_add_ps(norm_a0, norm_a1), _mm256_add_ps(norm_a2, norm_a3)); + __m256 norm_b_total = _mm256_add_ps(_mm256_add_ps(norm_b0, norm_b1), _mm256_add_ps(norm_b2, norm_b3)); + + float dot_result = horizontal_sum_avx2(dot_total); + float norm_a_result = horizontal_sum_avx2(norm_a_total); + float norm_b_result = horizontal_sum_avx2(norm_b_total); + + // Handle remaining tail with scalar loop + for (; i < elementCount; ++i) { + float ai = a[i]; + float bi = b[i]; + dot_result += ai * bi; + norm_a_result += ai * ai; + norm_b_result += bi * bi; + } + + float denom = sqrtf(norm_a_result) * sqrtf(norm_b_result); + if (denom == 0.0f) { + return 0.0f; + } + return dot_result / denom; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float dotf32(const float *a, const float *b, size_t elementCount) { + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + size_t i = 0; + // Each __m256 holds 8 floats, so unroll 4x = 32 floats per loop + size_t unrolled_limit = elementCount & ~31UL; + for (; i < unrolled_limit; i += 32) { + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i), acc0); + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i + 8), _mm256_loadu_ps(b + i + 8), acc1); + acc2 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i + 16), _mm256_loadu_ps(b + i + 16), acc2); + acc3 = _mm256_fmadd_ps(_mm256_loadu_ps(a + i + 24), _mm256_loadu_ps(b + i + 24), acc3); + } + + // Combine all partial sums + __m256 total_sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3)); + float result = horizontal_sum_avx2(total_sum); + + for (; i < elementCount; ++i) { + result += a[i] * b[i]; + } + + return result; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { + __m256 sum0 = _mm256_setzero_ps(); + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + __m256 sum3 = _mm256_setzero_ps(); + + size_t i = 0; + size_t unrolled_limit = elementCount & ~31UL; + // Each __m256 holds 8 floats, so unroll 4x = 32 floats per loop + for (; i < unrolled_limit; i += 32) { + __m256 d0 = _mm256_sub_ps(_mm256_loadu_ps(a + i), _mm256_loadu_ps(b + i)); + __m256 d1 = _mm256_sub_ps(_mm256_loadu_ps(a + i + 8), _mm256_loadu_ps(b + i + 8)); + __m256 d2 = _mm256_sub_ps(_mm256_loadu_ps(a + i + 16), _mm256_loadu_ps(b + i + 16)); + __m256 d3 = _mm256_sub_ps(_mm256_loadu_ps(a + i + 24), _mm256_loadu_ps(b + i + 24)); + + sum0 = _mm256_fmadd_ps(d0, d0, sum0); + sum1 = _mm256_fmadd_ps(d1, d1, sum1); + sum2 = _mm256_fmadd_ps(d2, d2, sum2); + sum3 = _mm256_fmadd_ps(d3, d3, sum3); + } + + // reduce all partial sums + __m256 total_sum = _mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3)); + float result = horizontal_sum_avx2(total_sum); + + for (; i < elementCount; ++i) { + float diff = a[i] - b[i]; + result += diff * diff; + } + + return result; +} diff --git a/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp index f851b2a13a9ea..dd062f8210c3c 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp +++ b/libs/simdvec/native/src/vec/c/amd64/vec_2.cpp @@ -9,6 +9,7 @@ #include #include +#include #include "vec.h" #ifdef _MSC_VER @@ -195,6 +196,151 @@ EXPORT int32_t sqr7u_2(int8_t* a, int8_t* b, size_t dims) { return res; } +// --- single precision floats + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +extern "C" +EXPORT float cosf32_2(const float *a, const float *b, size_t elementCount) { + __m512 dot0 = _mm512_setzero_ps(); + __m512 dot1 = _mm512_setzero_ps(); + __m512 dot2 = _mm512_setzero_ps(); + __m512 dot3 = _mm512_setzero_ps(); + + __m512 norm_a0 = _mm512_setzero_ps(); + __m512 norm_a1 = _mm512_setzero_ps(); + __m512 norm_a2 = _mm512_setzero_ps(); + __m512 norm_a3 = _mm512_setzero_ps(); + + __m512 norm_b0 = _mm512_setzero_ps(); + __m512 norm_b1 = _mm512_setzero_ps(); + __m512 norm_b2 = _mm512_setzero_ps(); + __m512 norm_b3 = _mm512_setzero_ps(); + + size_t i = 0; + // Each __m512 holds 16 floats, so unroll 4x = 64 floats per loop + size_t unrolled_limit = elementCount & ~63UL; + for (; i < unrolled_limit; i += 64) { + // Load and compute 4 blocks of 16 elements + __m512 a0 = _mm512_loadu_ps(a + i); + __m512 b0 = _mm512_loadu_ps(b + i); + __m512 a1 = _mm512_loadu_ps(a + i + 16); + __m512 b1 = _mm512_loadu_ps(b + i + 16); + __m512 a2 = _mm512_loadu_ps(a + i + 32); + __m512 b2 = _mm512_loadu_ps(b + i + 32); + __m512 a3 = _mm512_loadu_ps(a + i + 48); + __m512 b3 = _mm512_loadu_ps(b + i + 48); + + dot0 = _mm512_fmadd_ps(a0, b0, dot0); + dot1 = _mm512_fmadd_ps(a1, b1, dot1); + dot2 = _mm512_fmadd_ps(a2, b2, dot2); + dot3 = _mm512_fmadd_ps(a3, b3, dot3); + + norm_a0 = _mm512_fmadd_ps(a0, a0, norm_a0); + norm_a1 = _mm512_fmadd_ps(a1, a1, norm_a1); + norm_a2 = _mm512_fmadd_ps(a2, a2, norm_a2); + norm_a3 = _mm512_fmadd_ps(a3, a3, norm_a3); + + norm_b0 = _mm512_fmadd_ps(b0, b0, norm_b0); + norm_b1 = _mm512_fmadd_ps(b1, b1, norm_b1); + norm_b2 = _mm512_fmadd_ps(b2, b2, norm_b2); + norm_b3 = _mm512_fmadd_ps(b3, b3, norm_b3); + } + + // combine and reduce vector accumulators + __m512 dot_total = _mm512_add_ps(_mm512_add_ps(dot0, dot1), _mm512_add_ps(dot2, dot3)); + __m512 norm_a_total = _mm512_add_ps(_mm512_add_ps(norm_a0, norm_a1), _mm512_add_ps(norm_a2, norm_a3)); + __m512 norm_b_total = _mm512_add_ps(_mm512_add_ps(norm_b0, norm_b1), _mm512_add_ps(norm_b2, norm_b3)); + + float dot_result = _mm512_reduce_add_ps(dot_total); + float norm_a_result = _mm512_reduce_add_ps(norm_a_total); + float norm_b_result = _mm512_reduce_add_ps(norm_b_total); + + // Handle remaining tail with scalar loop + for (; i < elementCount; ++i) { + float ai = a[i]; + float bi = b[i]; + dot_result += ai * bi; + norm_a_result += ai * ai; + norm_b_result += bi * bi; + } + + float denom = sqrtf(norm_a_result) * sqrtf(norm_b_result); + if (denom == 0.0f) { + return 0.0f; + } + return dot_result / denom; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +extern "C" +EXPORT float dotf32_2(const float *a, const float *b, size_t elementCount) { + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + size_t i = 0; + size_t unrolled_limit = elementCount & ~63UL; + // Each __m512 holds 16 floats, so unroll 4x = 64 floats per loop + for (; i < unrolled_limit; i += 64) { + sum0 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i), _mm512_loadu_ps(b + i), sum0); + sum1 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 16), _mm512_loadu_ps(b + i + 16), sum1); + sum2 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 32), _mm512_loadu_ps(b + i + 32), sum2); + sum3 = _mm512_fmadd_ps(_mm512_loadu_ps(a + i + 48), _mm512_loadu_ps(b + i + 48), sum3); + } + + // reduce all partial sums + __m512 total_sum = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3)); + float result = _mm512_reduce_add_ps(total_sum); + + for (; i < elementCount; ++i) { + result += a[i] * b[i]; + } + + return result; +} + +// const float *a pointer to the first float vector +// const float *b pointer to the second float vector +// size_t elementCount the number of floating point elements +extern "C" +EXPORT float sqrf32_2(const float *a, const float *b, size_t elementCount) { + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + size_t i = 0; + size_t unrolled_limit = elementCount & ~63UL; + // Each __m512 holds 16 floats, so unroll 4x = 64 floats per loop + for (; i < unrolled_limit; i += 64) { + __m512 d0 = _mm512_sub_ps(_mm512_loadu_ps(a + i), _mm512_loadu_ps(b + i)); + __m512 d1 = _mm512_sub_ps(_mm512_loadu_ps(a + i + 16), _mm512_loadu_ps(b + i + 16)); + __m512 d2 = _mm512_sub_ps(_mm512_loadu_ps(a + i + 32), _mm512_loadu_ps(b + i + 32)); + __m512 d3 = _mm512_sub_ps(_mm512_loadu_ps(a + i + 48), _mm512_loadu_ps(b + i + 48)); + + sum0 = _mm512_fmadd_ps(d0, d0, sum0); + sum1 = _mm512_fmadd_ps(d1, d1, sum1); + sum2 = _mm512_fmadd_ps(d2, d2, sum2); + sum3 = _mm512_fmadd_ps(d3, d3, sum3); + } + + // reduce all partial sums + __m512 total_sum = _mm512_add_ps(_mm512_add_ps(sum0, sum1), _mm512_add_ps(sum2, sum3)); + float result = _mm512_reduce_add_ps(total_sum); + + for (; i < elementCount; ++i) { + float diff = a[i] - b[i]; + result += diff * diff; + } + + return result; +} + #ifdef __clang__ #pragma clang attribute pop #elif __GNUC__ diff --git a/libs/simdvec/native/src/vec/headers/vec.h b/libs/simdvec/native/src/vec/headers/vec.h index e27e9a3a68083..733aea3165659 100644 --- a/libs/simdvec/native/src/vec/headers/vec.h +++ b/libs/simdvec/native/src/vec/headers/vec.h @@ -20,3 +20,10 @@ EXPORT int vec_caps(); EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims); EXPORT int32_t sqr7u(int8_t *a, int8_t *b, size_t length); + +EXPORT float cosf32(const float *a, const float *b, size_t elementCount); + +EXPORT float dotf32(const float *a, const float *b, size_t elementCount); + +EXPORT float sqrf32(const float *a, const float *b, size_t elementCount); + diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java index 4ed60b2f5e8b2..cec456d70afb7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactory.java @@ -9,6 +9,7 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; @@ -53,4 +54,32 @@ Optional getInt7SQVectorScorerSupplier( * @return an optional containing the vector scorer, or empty */ Optional getInt7SQVectorScorer(VectorSimilarityFunction sim, QuantizedByteVectorValues values, float[] queryVector); + + /** + * Returns an optional containing a float32 vector scorer for the given + * parameters, or an empty optional if a scorer is not supported. + * + * @param similarityType the similarity type + * @param input the index input containing the vector data; + * offset of the first vector is 0, + * the length must be maxOrd * dims * Float#BYTES. + * @param values the random access vector values + * @return an optional containing the vector scorer, or empty + */ + Optional getFloat32VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ); + + /** + * Returns an optional containing a float32 vector scorer for the given + * parameters, or an empty optional if a scorer is not supported. + * + * @param sim the similarity type + * @param values the random access vector values + * @param queryVector the query vector + * @return an optional containing the vector scorer, or empty + */ + Optional getFloat32VectorScorer(VectorSimilarityFunction sim, FloatVectorValues values, float[] queryVector); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index 6248902c32e7a..f43bc1062e6c7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -9,6 +9,7 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; @@ -39,4 +40,22 @@ public Optional getInt7SQVectorScorer( ) { throw new UnsupportedOperationException("should not reach here"); } + + @Override + public Optional getFloat32VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ) { + throw new UnsupportedOperationException("should not reach here"); + } + + @Override + public Optional getFloat32VectorScorer( + VectorSimilarityFunction sim, + FloatVectorValues values, + float[] queryVector + ) { + throw new UnsupportedOperationException("should not reach here"); + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java index a863d9e3448ca..7c45ac04c59b8 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/VectorScorerFactoryImpl.java @@ -9,18 +9,17 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.elasticsearch.nativeaccess.NativeAccess; +import org.elasticsearch.simdvec.internal.Float32VectorScorer; +import org.elasticsearch.simdvec.internal.Float32VectorScorerSupplier; import org.elasticsearch.simdvec.internal.Int7SQVectorScorer; -import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.DotProductSupplier; -import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.EuclideanSupplier; -import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier.MaxInnerProductSupplier; +import org.elasticsearch.simdvec.internal.Int7SQVectorScorerSupplier; import java.util.Optional; @@ -41,17 +40,7 @@ public Optional getInt7SQVectorScorerSupplier( QuantizedByteVectorValues values, float scoreCorrectionConstant ) { - input = FilterIndexInput.unwrapOnlyTest(input); - if (input instanceof MemorySegmentAccessInput == false) { - return Optional.empty(); - } - MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; - checkInvariants(values.size(), values.dimension(), input); - return switch (similarityType) { - case COSINE, DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values, scoreCorrectionConstant)); - case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values, scoreCorrectionConstant)); - case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values, scoreCorrectionConstant)); - }; + return Int7SQVectorScorerSupplier.create(similarityType, input, values, scoreCorrectionConstant); } @Override @@ -63,9 +52,28 @@ public Optional getInt7SQVectorScorer( return Int7SQVectorScorer.create(sim, values, queryVector); } - static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { - if (input.length() < (long) vectorByteLength * maxOrd) { - throw new IllegalArgumentException("input length is less than expected vector data"); - } + // -- floats + + @Override + public Optional getFloat32VectorScorerSupplier( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ) { + return Float32VectorScorerSupplier.create(similarityType, input, values); + } + + @Override + public Optional getFloat32VectorScorer( + VectorSimilarityFunction sim, + FloatVectorValues values, + float[] queryVector + ) { + return Float32VectorScorer.create(sim, values, queryVector); + } + + @Override + public String toString() { + return "VectorScorerFactoryImpl"; } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java new file mode 100644 index 0000000000000..89a0ea0df8eeb --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.util.Optional; + +public class Float32VectorScorer { + + // Unconditionally returns an empty optional on <= JDK 21, since the scorer is only supported on JDK 22+ + public static Optional create(VectorSimilarityFunction sim, FloatVectorValues values, float[] queryVector) { + return Optional.empty(); + } + + private Float32VectorScorer() {} +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorerSupplier.java new file mode 100644 index 0000000000000..74f7b43ba41c8 --- /dev/null +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Float32VectorScorerSupplier.java @@ -0,0 +1,209 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; + +public abstract sealed class Float32VectorScorerSupplier implements RandomVectorScorerSupplier { + + public static Optional create( + VectorSimilarityType similarityType, + IndexInput input, + FloatVectorValues values + ) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.dimension(), input); + return switch (similarityType) { + case COSINE -> Optional.of(new CosineSupplier(msInput, values)); + case DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values)); + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values)); + }; + } + + final int dims; + final int vectorByteSize; + final int maxOrd; + final MemorySegmentAccessInput input; + final FloatVectorValues values; // to support ordToDoc/getAcceptOrds + final VectorSimilarityFunction fallbackScorer; + + protected Float32VectorScorerSupplier( + MemorySegmentAccessInput input, + FloatVectorValues values, + VectorSimilarityFunction fallbackScorer + ) { + this.input = input; + this.values = values; + this.dims = values.dimension(); + this.vectorByteSize = values.getVectorByteLength(); + this.maxOrd = values.size(); + this.fallbackScorer = fallbackScorer; + } + + protected final void checkOrdinal(int ord) { + if (ord < 0 || ord > maxOrd) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + final float scoreFromOrds(int firstOrd, int secondOrd) throws IOException { + long firstByteOffset = (long) firstOrd * vectorByteSize; + long secondByteOffset = (long) secondOrd * vectorByteSize; + + MemorySegment firstSeg = input.segmentSliceOrNull(firstByteOffset, vectorByteSize); + if (firstSeg == null) { + return fallbackScore(firstByteOffset, secondByteOffset); + } + + MemorySegment secondSeg = input.segmentSliceOrNull(secondByteOffset, vectorByteSize); + if (secondSeg == null) { + return fallbackScore(firstByteOffset, secondByteOffset); + } + + return scoreFromSegments(firstSeg, secondSeg); + } + + abstract float scoreFromSegments(MemorySegment a, MemorySegment b); + + protected final float fallbackScore(long firstByteOffset, long secondByteOffset) throws IOException { + float[] a = new float[dims]; + readFloats(a, firstByteOffset, dims); + + float[] b = new float[dims]; + readFloats(b, secondByteOffset, dims); + + return fallbackScorer.compare(a, b); + } + + final void readFloats(float[] floats, long byteOffset, int len) throws IOException { + for (int i = 0; i < len; i++) { + floats[i] = Float.intBitsToFloat(input.readInt(byteOffset)); + byteOffset += Float.BYTES; + } + } + + @Override + public UpdateableRandomVectorScorer scorer() { + return new UpdateableRandomVectorScorer.AbstractUpdateableRandomVectorScorer(values) { + private int ord = -1; + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + return scoreFromOrds(ord, node); + } + + @Override + public void setScoringOrdinal(int node) throws IOException { + checkOrdinal(node); + this.ord = node; + } + }; + } + + public static final class CosineSupplier extends Float32VectorScorerSupplier { + + public CosineSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.COSINE); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float v = Similarities.cosineFloat32(a, b, dims); + return Math.max((1 + v) / 2, 0); + } + + @Override + public CosineSupplier copy() { + return new CosineSupplier(input.clone(), values); + } + } + + public static final class EuclideanSupplier extends Float32VectorScorerSupplier { + + public EuclideanSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.EUCLIDEAN); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float v = Similarities.squareDistanceFloat32(a, b, dims); + return 1 / (1 + v); + } + + @Override + public EuclideanSupplier copy() { + return new EuclideanSupplier(input.clone(), values); + } + } + + public static final class DotProductSupplier extends Float32VectorScorerSupplier { + + public DotProductSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.DOT_PRODUCT); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float dotProduct = Similarities.dotProductFloat32(a, b, dims); + return Math.max((1 + dotProduct) / 2, 0); + } + + @Override + public DotProductSupplier copy() { + return new DotProductSupplier(input.clone(), values); + } + } + + public static final class MaxInnerProductSupplier extends Float32VectorScorerSupplier { + + public MaxInnerProductSupplier(MemorySegmentAccessInput input, FloatVectorValues values) { + super(input, values, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); + } + + @Override + float scoreFromSegments(MemorySegment a, MemorySegment b) { + float v = Similarities.dotProductFloat32(a, b, dims); + return VectorUtil.scaleMaxInnerProductScore(v); + } + + @Override + public MaxInnerProductSupplier copy() { + return new MaxInnerProductSupplier(input.clone(), values); + } + } + + static boolean checkIndex(long index, long length) { + return index >= 0 && index < length; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } +} diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java index 19f33ba1c71f7..83ee1532472c2 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java @@ -9,14 +9,18 @@ package org.elasticsearch.simdvec.internal; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; +import org.elasticsearch.simdvec.VectorSimilarityType; import java.io.IOException; import java.lang.foreign.MemorySegment; +import java.util.Optional; import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT; import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; @@ -27,6 +31,25 @@ public abstract sealed class Int7SQVectorScorerSupplier implements RandomVectorS static final byte BITS = 7; + public static Optional create( + VectorSimilarityType similarityType, + IndexInput input, + QuantizedByteVectorValues values, + float scoreCorrectionConstant + ) { + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.dimension(), input); + return switch (similarityType) { + case COSINE, DOT_PRODUCT -> Optional.of(new DotProductSupplier(msInput, values, scoreCorrectionConstant)); + case EUCLIDEAN -> Optional.of(new EuclideanSupplier(msInput, values, scoreCorrectionConstant)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductSupplier(msInput, values, scoreCorrectionConstant)); + }; + } + final int dims; final int maxOrd; final float scoreCorrectionConstant; @@ -172,4 +195,10 @@ public MaxInnerProductSupplier copy() { static boolean checkIndex(long index, long length) { return index >= 0 && index < length; } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java index 482bbc8d8cabe..565b92892dbb3 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java @@ -23,6 +23,9 @@ public class Similarities { static final MethodHandle DOT_PRODUCT_7U = DISTANCE_FUNCS.dotProductHandle7u(); static final MethodHandle SQUARE_DISTANCE_7U = DISTANCE_FUNCS.squareDistanceHandle7u(); + static final MethodHandle COS_PRODUCT_FLOAT32 = DISTANCE_FUNCS.cosineHandleFloat32(); + static final MethodHandle DOT_PRODUCT_FLOAT32 = DISTANCE_FUNCS.dotProductHandleFloat32(); + static final MethodHandle SQR_PRODUCT_FLOAT32 = DISTANCE_FUNCS.squareDistanceHandleFloat32(); static int dotProduct7u(MemorySegment a, MemorySegment b, int length) { try { @@ -51,4 +54,46 @@ static int squareDistance7u(MemorySegment a, MemorySegment b, int length) { } } } + + static float cosineFloat32(MemorySegment a, MemorySegment b, int elementCount) { + try { + return (float) COS_PRODUCT_FLOAT32.invokeExact(a, b, elementCount); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + static float dotProductFloat32(MemorySegment a, MemorySegment b, int elementCount) { + try { + return (float) DOT_PRODUCT_FLOAT32.invokeExact(a, b, elementCount); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + static float squareDistanceFloat32(MemorySegment a, MemorySegment b, int elementCount) { + try { + return (float) SQR_PRODUCT_FLOAT32.invokeExact(a, b, elementCount); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } } diff --git a/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java new file mode 100644 index 0000000000000..3dd1a179dd6dd --- /dev/null +++ b/libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Float32VectorScorer.java @@ -0,0 +1,150 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec.internal; + +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.FilterIndexInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.MemorySegmentAccessInput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +import java.io.IOException; +import java.lang.foreign.MemorySegment; +import java.util.Optional; + +public abstract sealed class Float32VectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + + final int dims; + final int vectorByteSize; + final MemorySegmentAccessInput input; + final MemorySegment query; + byte[] scratch; + + /** Return an optional whose value, if present, is the scorer. Otherwise, an empty optional is returned. */ + public static Optional create(VectorSimilarityFunction sim, FloatVectorValues values, float[] queryVector) { + checkDimensions(queryVector.length, values.dimension()); + IndexInput input = null; + if (values instanceof HasIndexSlice hasIndexSlice) { + input = hasIndexSlice.getSlice(); + } + if (input == null) { + return Optional.empty(); + } + input = FilterIndexInput.unwrapOnlyTest(input); + if (input instanceof MemorySegmentAccessInput == false) { + return Optional.empty(); + } + MemorySegmentAccessInput msInput = (MemorySegmentAccessInput) input; + checkInvariants(values.size(), values.getVectorByteLength(), input); + + return switch (sim) { + case COSINE -> Optional.of(new CosineScorer(msInput, values, queryVector)); + case DOT_PRODUCT -> Optional.of(new DotProductScorer(msInput, values, queryVector)); + case EUCLIDEAN -> Optional.of(new EuclideanScorer(msInput, values, queryVector)); + case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(msInput, values, queryVector)); + }; + } + + Float32VectorScorer(MemorySegmentAccessInput input, FloatVectorValues values, float[] queryVector) { + super(values); + this.input = input; + assert queryVector.length * Float.BYTES == values.getVectorByteLength() && queryVector.length == values.dimension(); + this.vectorByteSize = values.getVectorByteLength(); + this.dims = values.dimension(); + this.query = MemorySegment.ofArray(queryVector); + } + + final MemorySegment getSegment(int ord) throws IOException { + checkOrdinal(ord); + long byteOffset = (long) ord * vectorByteSize; + MemorySegment seg = input.segmentSliceOrNull(byteOffset, vectorByteSize); + if (seg == null) { + if (scratch == null) { + scratch = new byte[vectorByteSize]; + } + input.readBytes(byteOffset, scratch, 0, vectorByteSize); + seg = MemorySegment.ofArray(scratch); + } + return seg; + } + + static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) { + if (input.length() < (long) vectorByteLength * maxOrd) { + throw new IllegalArgumentException("input length is less than expected vector data"); + } + } + + final void checkOrdinal(int ord) { + if (ord < 0 || ord >= maxOrd()) { + throw new IllegalArgumentException("illegal ordinal: " + ord); + } + } + + public static final class CosineScorer extends Float32VectorScorer { + public CosineScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = Similarities.cosineFloat32(query, getSegment(node), dims); + return Math.max((1 + dotProduct) / 2, 0); + } + } + + public static final class DotProductScorer extends Float32VectorScorer { + public DotProductScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float dotProduct = Similarities.dotProductFloat32(query, getSegment(node), dims); + return Math.max((1 + dotProduct) / 2, 0); + } + } + + public static final class EuclideanScorer extends Float32VectorScorer { + public EuclideanScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float v = Similarities.squareDistanceFloat32(query, getSegment(node), dims); + return 1 / (1 + v); + } + } + + public static final class MaxInnerProductScorer extends Float32VectorScorer { + public MaxInnerProductScorer(MemorySegmentAccessInput in, FloatVectorValues values, float[] query) { + super(in, values, query); + } + + @Override + public float score(int node) throws IOException { + checkOrdinal(node); + float v = Similarities.dotProductFloat32(query, getSegment(node), dims); + return VectorUtil.scaleMaxInnerProductScore(v); + } + } + + static void checkDimensions(int queryLen, int fieldLen) { + if (queryLen != fieldLen) { + throw new IllegalArgumentException("vector query dimension: " + queryLen + " differs from field dimension: " + fieldLen); + } + } +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java index 31c5daa81f92b..3905c462d29ca 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/AbstractVectorTestCase.java @@ -9,7 +9,6 @@ package org.elasticsearch.simdvec; -import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.elasticsearch.test.ESTestCase; import org.junit.BeforeClass; @@ -62,17 +61,9 @@ public static String platformMsg() { return "JDK=" + jdkVersion + ", os=" + osName + ", arch=" + arch; } - /** Computes the score using the Lucene implementation. */ - public static float luceneScore( - VectorSimilarityType similarityFunc, - byte[] a, - byte[] b, - float correction, - float aOffsetValue, - float bOffsetValue - ) { - var scorer = ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityType.of(similarityFunc), correction, (byte) 7); - return scorer.score(a, aOffsetValue, b, bOffsetValue); + // Support for passing on-heap arrays/segments to native + protected static boolean supportsHeapSegments() { + return Runtime.version().feature() >= 22; } /** Converts a float value to a byte array. */ diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Float32VectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Float32VectorScorerFactoryTests.java new file mode 100644 index 0000000000000..27740e4ed43f0 --- /dev/null +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Float32VectorScorerFactoryTests.java @@ -0,0 +1,409 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.simdvec; + +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import java.util.stream.IntStream; + +import static org.elasticsearch.simdvec.VectorSimilarityType.COSINE; +import static org.elasticsearch.simdvec.VectorSimilarityType.DOT_PRODUCT; +import static org.elasticsearch.simdvec.VectorSimilarityType.EUCLIDEAN; +import static org.elasticsearch.simdvec.VectorSimilarityType.MAXIMUM_INNER_PRODUCT; +import static org.elasticsearch.test.hamcrest.OptionalMatchers.isEmpty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; + +public class Float32VectorScorerFactoryTests extends AbstractVectorTestCase { + + static final double BASE_DELTA = 1e-5; + + // Tests that the provider instance is present or not on expected platforms/architectures + public void testSupport() { + supported(); + } + + public void testSimple() throws IOException { + testSimpleImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testSimpleMaxChunkSizeSmall() throws IOException { + long maxChunkSize = randomLongBetween(4, 16); + logger.info("maxChunkSize=" + maxChunkSize); + testSimpleImpl(maxChunkSize); + } + + void testSimpleImpl(long maxChunkSize) throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testSimpleImpl"), maxChunkSize)) { + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + for (int dims : List.of(31, 32, 33)) { + // dimensions that cross the scalar / native boundary (stride) + float[] vector1 = new float[dims]; + float[] vector2 = new float[dims]; + String fileName = "testSimpleImpl-" + sim + "-" + dims + ".vex"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < dims; i++) { + vector1[i] = (float) i; + vector2[i] = (float) (dims - i); + } + writeFloat32Vectors(out, vector1, vector2); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 2, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vector1, vector2); + + var luceneSupplier = luceneScoreSupplier(values, VectorSimilarityType.of(sim)).scorer(); + luceneSupplier.setScoringOrdinal(0); + assertEquals(expected, luceneSupplier.score(1), BASE_DELTA); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertEquals(expected, scorer.score(1), BASE_DELTA); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vector1).get(); + assertEquals(expected, qScorer.score(1), BASE_DELTA); + } + } + } + } + } + } + + public void testNonNegative() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testNonNegativeDotProduct"), MMapDirectory.DEFAULT_MAX_CHUNK_SIZE)) { + float[] vec1 = new float[32]; + float[] vec2 = new float[32]; + String fileName = "testNonNegativeDotProduct-32"; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + vec1[0] = -2.0f; // values to trigger a negative dot product + vec2[0] = 1.0f; + writeFloat32Vectors(out, vec1, vec2); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(32, 2, in, VectorSimilarityType.of(DOT_PRODUCT)); + // dot product + float expected = 0.0f; + assertThat(luceneFloat32Score(DOT_PRODUCT, vec1, vec2), equalTo(expected)); + var supplier = factory.getFloat32VectorScorerSupplier(DOT_PRODUCT, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + // max inner product + expected = luceneFloat32Score(MAXIMUM_INNER_PRODUCT, vec1, vec2); + supplier = factory.getFloat32VectorScorerSupplier(MAXIMUM_INNER_PRODUCT, in, values).get(); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + assertThat(scorer.score(1), equalTo(expected)); + // cosine + expected = 0f; + assertThat(luceneFloat32Score(COSINE, vec1, vec2), equalTo(expected)); + supplier = factory.getFloat32VectorScorerSupplier(COSINE, in, values).get(); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + // euclidean + expected = luceneFloat32Score(EUCLIDEAN, vec1, vec2); + supplier = factory.getFloat32VectorScorerSupplier(EUCLIDEAN, in, values).get(); + scorer = supplier.scorer(); + scorer.setScoringOrdinal(0); + assertThat(scorer.score(1), equalTo(expected)); + assertThat(scorer.score(1), greaterThanOrEqualTo(0f)); + } + } + } + + public void testRandom() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomWithChunkSize(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE); + } + + public void testRandomMaxChunkSizeSmall() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + long maxChunkSize = randomLongBetween(32, 128); + logger.info("maxChunkSize=" + maxChunkSize); + testRandomWithChunkSize(maxChunkSize); + } + + void testRandomWithChunkSize(long maxChunkSize) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandom"), maxChunkSize)) { + final int dims = randomIntBetween(1, 4096); + final int size = randomIntBetween(2, 100); + final float[][] vectors = IntStream.range(0, size).mapToObj(i -> randomVector(dims)).toArray(float[][]::new); + final double delta = BASE_DELTA * dims; // scale the delta with the size + + String fileName = "testRandom-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + writeFloat32Vectors(out, vectors); + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vectors[idx0], vectors[idx1]); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertEquals(expected, scorer.score(idx1), delta); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get(); + assertEquals(expected, qScorer.score(idx1), delta); + } + } + } + } + } + } + + public void testRandomSlice() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + testRandomSliceImpl(30, 64, 1); + } + + void testRandomSliceImpl(int dims, long maxChunkSize, int initialPadding) throws IOException { + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testRandomSliceImpl"), maxChunkSize)) { + for (int times = 0; times < TIMES; times++) { + final int size = randomIntBetween(2, 100); + final float[][] vectors = IntStream.range(0, size).mapToObj(i -> randomVector(dims)).toArray(float[][]::new); + final double delta = BASE_DELTA * dims; // scale the delta with the size + + String fileName = "testRandomSliceImpl-" + times + "-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + byte[] ba = new byte[initialPadding]; + out.writeBytes(ba, 0, ba.length); + writeFloat32Vectors(out, vectors); + } + try ( + var outter = dir.openInput(fileName, IOContext.DEFAULT); + var in = outter.slice("slice", initialPadding, outter.length() - initialPadding) + ) { + for (int itrs = 0; itrs < TIMES / 10; itrs++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = randomIntBetween(0, size - 1); // may be the same as idx0 - which is ok. + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vectors[idx0], vectors[idx1]); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertEquals(expected, scorer.score(idx1), delta); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vectors[idx0]).get(); + assertEquals(expected, qScorer.score(idx1), delta); + } + } + } + } + } + } + } + + // Tests with a large amount of data (> 2GB), which ensures that data offsets do not overflow + @Nightly + public void testLarge() throws IOException { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + try (Directory dir = new MMapDirectory(createTempDir("testLarge"))) { + final int dims = 4096; + final int size = 262144; + assert (long) dims * Float.BYTES * size > Integer.MAX_VALUE; + final double delta = BASE_DELTA * dims; // scale the delta with the size + + String fileName = "testLarge-" + dims; + logger.info("Testing " + fileName); + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + for (int i = 0; i < size; i++) { + writeFloat32Vectors(out, vector(i, dims)); + } + } + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + for (int times = 0; times < TIMES; times++) { + int idx0 = randomIntBetween(0, size - 1); + int idx1 = size - 1; + for (var sim : List.of(COSINE, DOT_PRODUCT, EUCLIDEAN, MAXIMUM_INNER_PRODUCT)) { + var values = vectorValues(dims, size, in, VectorSimilarityType.of(sim)); + float expected = luceneFloat32Score(sim, vector(idx0, dims), vector(idx1, dims)); + var supplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var scorer = supplier.scorer(); + scorer.setScoringOrdinal(idx0); + assertEquals(expected, scorer.score(idx1), delta); + + if (supportsHeapSegments()) { + var qScorer = factory.getFloat32VectorScorer(VectorSimilarityType.of(sim), values, vector(idx0, dims)).get(); + assertEquals(expected, qScorer.score(idx1), delta); + } + } + } + } + } + } + + public void testRace() throws Exception { + testRaceImpl(COSINE); + testRaceImpl(DOT_PRODUCT); + testRaceImpl(EUCLIDEAN); + testRaceImpl(MAXIMUM_INNER_PRODUCT); + } + + // Tests that copies in threads do not interfere with each other + void testRaceImpl(VectorSimilarityType sim) throws Exception { + assumeTrue(notSupportedMsg(), supported()); + var factory = AbstractVectorTestCase.factory.get(); + + final long maxChunkSize = 32; + final int dims = 34; // dimensions that are larger than the chunk size, to force fallback + float[] vec1 = new float[dims]; + float[] vec2 = new float[dims]; + IntStream.range(0, dims).forEach(i -> vec1[i] = 1); + IntStream.range(0, dims).forEach(i -> vec2[i] = 2); + try (Directory dir = new MMapDirectory(createTempDir("testRace"), maxChunkSize)) { + String fileName = "testRace-" + dims; + try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) { + writeFloat32Vectors(out, vec1, vec1, vec2, vec2); + } + var expectedScore1 = luceneFloat32Score(sim, vec1, vec1); + var expectedScore2 = luceneFloat32Score(sim, vec2, vec2); + + try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { + var values = vectorValues(dims, 4, in, VectorSimilarityType.of(sim)); + var scoreSupplier = factory.getFloat32VectorScorerSupplier(sim, in, values).get(); + var tasks = List.>>of( + new ScoreCallable(scoreSupplier.copy().scorer(), 0, 1, expectedScore1), + new ScoreCallable(scoreSupplier.copy().scorer(), 2, 3, expectedScore2) + ); + var executor = Executors.newFixedThreadPool(2); + var results = executor.invokeAll(tasks); + executor.shutdown(); + assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS)); + assertThat(results.stream().filter(Predicate.not(Future::isDone)).count(), equalTo(0L)); + for (var res : results) { + assertThat("Unexpected exception" + res.get(), res.get(), isEmpty()); + } + } + } + } + + static class ScoreCallable implements Callable> { + + final UpdateableRandomVectorScorer scorer; + final int ord; + final float expectedScore; + + ScoreCallable(UpdateableRandomVectorScorer scorer, int queryOrd, int ord, float expectedScore) { + try { + this.scorer = scorer; + this.scorer.setScoringOrdinal(queryOrd); + this.ord = ord; + this.expectedScore = expectedScore; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public Optional call() { + try { + for (int i = 0; i < 100; i++) { + assertThat(scorer.score(ord), equalTo(expectedScore)); + } + } catch (Throwable t) { + return Optional.of(t); + } + return Optional.empty(); + } + } + + FloatVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + var slice = in.slice("values", 0, in.length()); + var byteSize = dims * Float.BYTES; + return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(dims, size, slice, byteSize, DefaultFlatVectorScorer.INSTANCE, sim); + } + + /** Computes the score using the Lucene implementation. */ + static float luceneFloat32Score(VectorSimilarityType similarityFunc, float[] a, float[] b) { + return VectorSimilarityType.of(similarityFunc).compare(a, b); + } + + RandomVectorScorerSupplier luceneScoreSupplier(KnnVectorValues values, VectorSimilarityFunction sim) throws IOException { + return DefaultFlatVectorScorer.INSTANCE.getRandomVectorScorerSupplier(sim, values); + } + + static float[] randomVector(int dim) { + float[] v = new float[dim]; + Random random = random(); + for (int i = 0; i < dim; i++) { + v[i] = random.nextFloat(); + } + return v; + } + + // creates the vector based on the given ordinal, which is reproducible given the ord and dims + static float[] vector(int ord, int dims) { + var random = new Random(Objects.hash(ord, dims)); + float[] fa = new float[dims]; + for (int i = 0; i < dims; i++) { + fa[i] = random.nextFloat(); + } + return fa; + } + + static void writeFloat32Vectors(IndexOutput out, float[]... vectors) throws IOException { + var buffer = ByteBuffer.allocate(vectors[0].length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (var v : vectors) { + buffer.asFloatBuffer().put(v); + out.writeBytes(buffer.array(), buffer.array().length); + } + } + + static final int TIMES = 100; // a loop iteration times +} diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java similarity index 96% rename from libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java rename to libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java index 070260759d6a0..eca67582b7a16 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/VectorScorerFactoryTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/Int7SQVectorScorerFactoryTests.java @@ -22,6 +22,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizer; import java.io.IOException; @@ -47,8 +48,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; -// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 100) -public class VectorScorerFactoryTests extends AbstractVectorTestCase { +public class Int7SQVectorScorerFactoryTests extends AbstractVectorTestCase { // bounds of the range of values that can be seen by int7 scalar quantized vectors static final byte MIN_INT7_VALUE = 0; @@ -107,7 +107,7 @@ void testSimpleImpl(long maxChunkSize) throws IOException { scorer.setScoringOrdinal(0); assertThat(scorer.score(1), equalTo(expected)); - if (Runtime.version().feature() >= 22) { + if (supportsHeapSegments()) { var qScorer = factory.getInt7SQVectorScorer(VectorSimilarityType.of(sim), values, query1).get(); assertThat(qScorer.score(1), equalTo(expected)); } @@ -229,11 +229,11 @@ void testRandomSupplier(long maxChunkSize, Function byteArraySu } public void testRandomScorer() throws IOException { - testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, VectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC); + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_RANDOM_FUNC); } public void testRandomScorerMax() throws IOException { - testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, VectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC); + testRandomScorerImpl(MMapDirectory.DEFAULT_MAX_CHUNK_SIZE, Int7SQVectorScorerFactoryTests.FLOAT_ARRAY_MAX_FUNC); } public void testRandomScorerChunkSizeSmall() throws IOException { @@ -461,6 +461,19 @@ QuantizedByteVectorValues vectorValues(int dims, int size, IndexInput in, Vector return new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(dims, size, sq, false, sim, null, slice); } + /** Computes the score using the Lucene implementation. */ + public static float luceneScore( + VectorSimilarityType similarityFunc, + byte[] a, + byte[] b, + float correction, + float aOffsetValue, + float bOffsetValue + ) { + var scorer = ScalarQuantizedVectorSimilarity.fromVectorSimilarity(VectorSimilarityType.of(similarityFunc), correction, (byte) 7); + return scorer.score(a, aOffsetValue, b, bOffsetValue); + } + RandomVectorScorerSupplier luceneScoreSupplier(QuantizedByteVectorValues values, VectorSimilarityFunction sim) throws IOException { return new Lucene99ScalarQuantizedVectorScorer(null).getRandomVectorScorerSupplier(sim, values); } diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index e6a44a0fba7a4..843780139ff8a 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -457,6 +457,7 @@ org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es819.ES819HnswVectorsFormat, org.elasticsearch.index.codec.vectors.IVFVectorsFormat; provides org.apache.lucene.codecs.Codec diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java index 29f62b64764a9..3281674316825 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java @@ -41,7 +41,7 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat { static final String NAME = "ES813FlatVectorFormat"; - private static final FlatVectorsFormat format = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE); + private static final FlatVectorsFormat format = new Lucene99FlatVectorsFormat(ESFlatVectorsScorer.create()); /** * Sole constructor diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ESFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ESFlatVectorsScorer.java new file mode 100644 index 0000000000000..6c82d75348ffd --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ESFlatVectorsScorer.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.simdvec.VectorScorerFactory; +import org.elasticsearch.simdvec.VectorSimilarityType; + +import java.io.IOException; +import java.util.Objects; + +/** + * Flat vectors scorer that wraps a given delegate. This scorer first checks the Elasticsearch + * simdvec scorer factory for an optimized implementation before falling back to the delegate + * if one does not exist. + * + *

The current implementation checks for optimized simdvec float32 scorers. + */ +public final class ESFlatVectorsScorer implements FlatVectorsScorer { + + static final VectorScorerFactory VECTOR_SCORER_FACTORY = VectorScorerFactory.instance().orElse(null); + + final FlatVectorsScorer delegate; + + /** Creates a FlatVectorsScorer delegating to Lucene's default scorer {@link FlatVectorScorerUtil#getLucene99FlatVectorsScorer()} + * if the platform does not have an optimized scorer factory. + */ + public static FlatVectorsScorer create() { + return create(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + } + + /** Creates a FlatVectorsScorer. Returns the delegate if the platform does not have an optimized scorer factory. */ + static FlatVectorsScorer create(FlatVectorsScorer delegate) { + Objects.requireNonNull(delegate); + if (VECTOR_SCORER_FACTORY == null) { + return delegate; + } + return new ESFlatVectorsScorer(delegate); + } + + private ESFlatVectorsScorer(FlatVectorsScorer delegate) { + this.delegate = delegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction sim, KnnVectorValues values) + throws IOException { + assert VECTOR_SCORER_FACTORY != null; + if (values instanceof FloatVectorValues fValues && fValues instanceof HasIndexSlice sliceable) { + if (sliceable.getSlice() != null) { + var scorer = VECTOR_SCORER_FACTORY.getFloat32VectorScorerSupplier( + VectorSimilarityType.of(sim), + sliceable.getSlice(), + fValues + ); + if (scorer.isPresent()) { + return scorer.get(); + } + } + } + return delegate.getRandomVectorScorerSupplier(sim, values); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, float[] query) + throws IOException { + assert VECTOR_SCORER_FACTORY != null; + if (values instanceof FloatVectorValues fValues && fValues instanceof HasIndexSlice sliceable) { + if (sliceable.getSlice() != null) { + var scorer = VECTOR_SCORER_FACTORY.getFloat32VectorScorer(sim, fValues, query); + if (scorer.isPresent()) { + return scorer.get(); + } + } + } + return delegate.getRandomVectorScorer(sim, values, query); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction sim, KnnVectorValues values, byte[] query) throws IOException { + assert VECTOR_SCORER_FACTORY != null; + return delegate.getRandomVectorScorer(sim, values, query); + } + + @Override + public String toString() { + return "ESFlatVectorsScorer(" + "delegate=" + delegate + ", factory=" + VECTOR_SCORER_FACTORY + ')'; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java new file mode 100644 index 0000000000000..f9ec1089530d7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es819; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.index.codec.vectors.ESFlatVectorsScorer; + +import java.io.IOException; + +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +// Minimal copy of Lucene99HnswVectorsFormat in order to provide an optimized scorer, +// which returns identical scores to that of the default flat vector scorer. +public class ES819HnswVectorsFormat extends KnnVectorsFormat { + + static final String NAME = "ES819HnswVectorsFormat"; + + static final int MAXIMUM_MAX_CONN = Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + static final int MAXIMUM_BEAM_WIDTH = Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; + + static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(ESFlatVectorsScorer.create()); + + private final int maxConn; + private final int beamWidth; + + public ES819HnswVectorsFormat() { + this(Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); + } + + public ES819HnswVectorsFormat(int maxConn, int beamWidth) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public String toString() { + return "ES819HnswVectorsFormat(name=ES819HnswVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 329e426be7f47..d1b17e47327b9 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -52,6 +52,7 @@ import org.elasticsearch.index.codec.vectors.IVFVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es819.ES819HnswVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -2048,7 +2049,7 @@ public KnnVectorsFormat getVectorsFormat(ElementType elementType) { if (elementType == ElementType.BIT) { return new ES815HnswBitVectorsFormat(m, efConstruction); } - return new Lucene99HnswVectorsFormat(m, efConstruction, 1, null); + return new ES819HnswVectorsFormat(m, efConstruction); } @Override diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 14e68029abc3b..b79f746f94d95 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -7,4 +7,5 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es819.ES819HnswVectorsFormat org.elasticsearch.index.codec.vectors.IVFVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormatTests.java new file mode 100644 index 0000000000000..78babb624dbc7 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormatTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es819; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.nativeaccess.NativeAccess; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES819HnswVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + static final boolean optimizedScorer = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + + static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES819HnswVectorsFormat()); + + @Override + protected Codec getCodec() { + return codec; + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES819HnswVectorsFormat(10, 20); + } + }; + var expectedScorer = optimizedScorer + ? "ESFlatVectorsScorer(delegate=%luceneScorer%, factory=VectorScorerFactoryImpl)" + : "%luceneScorer%"; + String expectedPattern = "ES819HnswVectorsFormat(name=ES819HnswVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%scorer%))".replace("%scorer%", expectedScorer); + + var defaultScorer = expectedPattern.replace("%luceneScorer%", "DefaultFlatVectorScorer()"); + var memSegScorer = expectedPattern.replace("%luceneScorer%", "Lucene99MemorySegmentFlatVectorsScorer()"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } +} From 876da01f4e9053c245a8d47efb0bf089d35083d4 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 3 Jul 2025 11:03:49 +0000 Subject: [PATCH 2/6] [CI] Auto commit changes from spotless --- .../elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java | 1 - .../index/codec/vectors/es819/ES819HnswVectorsFormat.java | 1 - 2 files changed, 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java index 3281674316825..206afbc857941 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java @@ -13,7 +13,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java index f9ec1089530d7..67945567ac4b9 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es819/ES819HnswVectorsFormat.java @@ -12,7 +12,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; From df2712590e8f562f80568246a43ca15e3412b7fb Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Thu, 3 Jul 2025 13:31:55 +0100 Subject: [PATCH 3/6] use built artifact --- libs/native/libraries/build.gradle | 4 ++-- libs/simdvec/native/publish_vec_binaries.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index e00090e6df4c4..53b187fbe960d 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -19,7 +19,7 @@ configurations { } var zstdVersion = "1.5.5" -var vecVersion = "1.0.11" +var vecVersion = "1.0.12" repositories { exclusiveContent { @@ -52,7 +52,7 @@ dependencies { libs "org.elasticsearch:zstd:${zstdVersion}:linux-aarch64" libs "org.elasticsearch:zstd:${zstdVersion}:linux-x86-64" libs "org.elasticsearch:zstd:${zstdVersion}:windows-x86-64" -// libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib + libs "org.elasticsearch:vec:${vecVersion}@zip" // temporarily comment this out, if testing a locally built native lib } def extractLibs = tasks.register('extractLibs', Copy) { diff --git a/libs/simdvec/native/publish_vec_binaries.sh b/libs/simdvec/native/publish_vec_binaries.sh index e3d7e4858ecfc..2dd2b4461b3de 100755 --- a/libs/simdvec/native/publish_vec_binaries.sh +++ b/libs/simdvec/native/publish_vec_binaries.sh @@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then exit 1; fi -VERSION="1.0.11" +VERSION="1.0.12" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" TEMP=$(mktemp -d) From d936652237fd1e79b7e1eb2a4b97a16cf2bda961 Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Thu, 3 Jul 2025 14:06:38 +0100 Subject: [PATCH 4/6] test itr --- .../mapper/vectors/DenseVectorFieldMapperTests.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index 0f161d4a1e44f..00d0f9cb9323f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -47,6 +47,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.search.lookup.Source; import org.elasticsearch.search.lookup.SourceProvider; import org.elasticsearch.search.vectors.VectorData; @@ -2621,6 +2622,8 @@ public void testFloatVectorQueryBoundaries() throws IOException { ); } + static final boolean optimizedScorer = NativeAccess.instance().getVectorSimilarityFunctions().isPresent(); + public void testKnnVectorsFormat() throws IOException { final int m = randomIntBetween(1, DEFAULT_MAX_CONN + 10); final int efConstruction = randomIntBetween(1, DEFAULT_BEAM_WIDTH + 10); @@ -2654,11 +2657,14 @@ public void testKnnVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=" + var expectedScorer = optimizedScorer + ? "ESFlatVectorsScorer(delegate=DefaultFlatVectorScorer(), factory=VectorScorerFactoryImpl)" + : "DefaultFlatVectorScorer()"; + String expectedString = "ES819HnswVectorsFormat(name=ES819HnswVectorsFormat, maxConn=" + (setM ? m : DEFAULT_MAX_CONN) + ", beamWidth=" + (setEfConstruction ? efConstruction : DEFAULT_BEAM_WIDTH) - + ", flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())" + + ", flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%scorer%)".replace("%scorer%", expectedScorer) + ")"; assertEquals(expectedString, knnVectorsFormat.toString()); } From e1a4d779d69da944e5198210e8c4ba00a3cdcad0 Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Fri, 4 Jul 2025 10:02:29 +0100 Subject: [PATCH 5/6] horizontal_sum_avx2 -> hsum_f32_8 --- libs/native/libraries/build.gradle | 2 +- libs/simdvec/native/publish_vec_binaries.sh | 2 +- libs/simdvec/native/src/vec/c/amd64/vec.c | 14 +++++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/libs/native/libraries/build.gradle b/libs/native/libraries/build.gradle index 53b187fbe960d..4d94ad6e20c73 100644 --- a/libs/native/libraries/build.gradle +++ b/libs/native/libraries/build.gradle @@ -19,7 +19,7 @@ configurations { } var zstdVersion = "1.5.5" -var vecVersion = "1.0.12" +var vecVersion = "1.0.13" repositories { exclusiveContent { diff --git a/libs/simdvec/native/publish_vec_binaries.sh b/libs/simdvec/native/publish_vec_binaries.sh index 2dd2b4461b3de..0258ed5760b6b 100755 --- a/libs/simdvec/native/publish_vec_binaries.sh +++ b/libs/simdvec/native/publish_vec_binaries.sh @@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then exit 1; fi -VERSION="1.0.12" +VERSION="1.0.13" ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}" TEMP=$(mktemp -d) diff --git a/libs/simdvec/native/src/vec/c/amd64/vec.c b/libs/simdvec/native/src/vec/c/amd64/vec.c index 24a648747541c..c6b9154b60660 100644 --- a/libs/simdvec/native/src/vec/c/amd64/vec.c +++ b/libs/simdvec/native/src/vec/c/amd64/vec.c @@ -191,8 +191,8 @@ EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { // --- single precision floats -// Horizontal add of all 8 elements in a __m256 register -static inline float horizontal_sum_avx2(__m256 v) { +// Horizontally add 8 float32 elements in a __m256 register +static inline float hsum_f32_8(const __m256 v) { // First, add the low and high 128-bit lanes __m128 low = _mm256_castps256_ps128(v); // lower 128 bits __m128 high = _mm256_extractf128_ps(v, 1); // upper 128 bits @@ -261,9 +261,9 @@ EXPORT float cosf32(const float *a, const float *b, size_t elementCount) { __m256 norm_a_total = _mm256_add_ps(_mm256_add_ps(norm_a0, norm_a1), _mm256_add_ps(norm_a2, norm_a3)); __m256 norm_b_total = _mm256_add_ps(_mm256_add_ps(norm_b0, norm_b1), _mm256_add_ps(norm_b2, norm_b3)); - float dot_result = horizontal_sum_avx2(dot_total); - float norm_a_result = horizontal_sum_avx2(norm_a_total); - float norm_b_result = horizontal_sum_avx2(norm_b_total); + float dot_result = hsum_f32_8(dot_total); + float norm_a_result = hsum_f32_8(norm_a_total); + float norm_b_result = hsum_f32_8(norm_b_total); // Handle remaining tail with scalar loop for (; i < elementCount; ++i) { @@ -302,7 +302,7 @@ EXPORT float dotf32(const float *a, const float *b, size_t elementCount) { // Combine all partial sums __m256 total_sum = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3)); - float result = horizontal_sum_avx2(total_sum); + float result = hsum_f32_8(total_sum); for (; i < elementCount; ++i) { result += a[i] * b[i]; @@ -337,7 +337,7 @@ EXPORT float sqrf32(const float *a, const float *b, size_t elementCount) { // reduce all partial sums __m256 total_sum = _mm256_add_ps(_mm256_add_ps(sum0, sum1), _mm256_add_ps(sum2, sum3)); - float result = horizontal_sum_avx2(total_sum); + float result = hsum_f32_8(total_sum); for (; i < elementCount; ++i) { float diff = a[i] - b[i]; From da8ac6e44caf2f3b7d8626d8b1c3adb26bd89f5e Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Tue, 8 Jul 2025 14:34:01 +0100 Subject: [PATCH 6/6] rework bench support for heap segs --- .../vector/Float32ScorerBenchmark.java | 26 ++++++++++++++----- .../vector/Float32ScorerBenchmarkTests.java | 12 ++++++--- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java index dbde4e92492a1..3d45e30d7abef 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmark.java @@ -23,6 +23,8 @@ import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer; import org.elasticsearch.common.logging.LogConfigurator; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.simdvec.VectorScorerFactory; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -62,6 +64,10 @@ public class Float32ScorerBenchmark { static { LogConfigurator.configureESLogging(); // native access requires logging to be initialized + if (supportsHeapSegments() == false) { + final Logger LOG = LogManager.getLogger(Float32ScorerBenchmark.class); + LOG.warn("*Query targets cannot run on " + "JDK " + Runtime.version()); + } } @Param({ "96", "768", "1024" }) @@ -120,14 +126,16 @@ public void setup() throws IOException { nativeSqrScorer.setScoringOrdinal(0); // setup for getFloat32VectorScorer / query vector scoring - float[] queryVec = new float[dims]; - for (int i = 0; i < dims; i++) { - queryVec[i] = ThreadLocalRandom.current().nextFloat(); + if (supportsHeapSegments()) { + float[] queryVec = new float[dims]; + for (int i = 0; i < dims; i++) { + queryVec[i] = ThreadLocalRandom.current().nextFloat(); + } + luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec); + nativeDotScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); + luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec); + nativeSqrScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get(); } - luceneDotScorerQuery = luceneScorer(values, VectorSimilarityFunction.DOT_PRODUCT, queryVec); - nativeDotScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); - luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec); - nativeSqrScorerQuery = factory.getFloat32VectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get(); } @TearDown @@ -188,6 +196,10 @@ public float squareDistanceNativeQuery() throws IOException { return nativeSqrScorerQuery.score(1) + nativeSqrScorerQuery.score(2); } + static boolean supportsHeapSegments() { + return Runtime.version().feature() >= 22; + } + static float dotProductScalarImpl(float[] vec1, float[] vec2) { float dot = 0; for (int i = 0; i < vec1.length; i++) { diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java index ca262b36e41d3..cd0088b9f8a13 100644 --- a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Float32ScorerBenchmarkTests.java @@ -42,8 +42,10 @@ public void testDotProduct() throws Exception { assertEquals(expected, bench.dotProductLucene(), delta); assertEquals(expected, bench.dotProductNative(), delta); - expected = bench.dotProductLuceneQuery(); - assertEquals(expected, bench.dotProductNativeQuery(), delta); + if (Float32ScorerBenchmark.supportsHeapSegments()) { + expected = bench.dotProductLuceneQuery(); + assertEquals(expected, bench.dotProductNativeQuery(), delta); + } } finally { bench.teardown(); } @@ -60,8 +62,10 @@ public void testSquareDistance() throws Exception { assertEquals(expected, bench.squareDistanceLucene(), delta); assertEquals(expected, bench.squareDistanceNative(), delta); - expected = bench.squareDistanceLuceneQuery(); - assertEquals(expected, bench.squareDistanceNativeQuery(), delta); + if (Float32ScorerBenchmark.supportsHeapSegments()) { + expected = bench.squareDistanceLuceneQuery(); + assertEquals(expected, bench.squareDistanceNativeQuery(), delta); + } } finally { bench.teardown(); }