Skip to content

Commit 20ef955

Browse files
authored
Refactor JDKVectorLibraryTests to JDKVectorLibraryInt7uTests (elastic#130617)
This commit refactors JDKVectorLibraryTests to JDKVectorLibraryInt7uTests, in order to make space for other vector scorer benchmarks, namely float32.
1 parent e06de3e commit 20ef955

File tree

2 files changed

+50
-31
lines changed

2 files changed

+50
-31
lines changed

libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,54 @@
99

1010
package org.elasticsearch.nativeaccess;
1111

12+
import org.elasticsearch.common.logging.LogConfigurator;
13+
import org.elasticsearch.common.logging.NodeNamePatternConverter;
1214
import org.elasticsearch.test.ESTestCase;
1315

16+
import java.lang.foreign.Arena;
17+
import java.util.Arrays;
1418
import java.util.Optional;
19+
import java.util.stream.IntStream;
1520

1621
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
1722
import static org.hamcrest.Matchers.not;
1823

19-
public class VectorSimilarityFunctionsTests extends ESTestCase {
24+
public abstract class VectorSimilarityFunctionsTests extends ESTestCase {
2025

21-
final Optional<VectorSimilarityFunctions> vectorSimilarityFunctions;
26+
static {
27+
NodeNamePatternConverter.setGlobalNodeName("foo");
28+
LogConfigurator.loadLog4jPlugins();
29+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
30+
}
31+
32+
public static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
33+
public static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
34+
35+
protected static Arena arena;
36+
37+
protected final int size;
38+
protected final Optional<VectorSimilarityFunctions> vectorSimilarityFunctions;
39+
40+
protected static Iterable<Object[]> parametersFactory() {
41+
var dims1 = Arrays.stream(new int[] { 1, 2, 4, 6, 8, 12, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 768 });
42+
var dims2 = Arrays.stream(new int[] { 1000, 1023, 1024, 1025, 2047, 2048, 2049, 4095, 4096, 4097 });
43+
return () -> IntStream.concat(dims1, dims2).boxed().map(i -> new Object[] { i }).iterator();
44+
}
2245

23-
public VectorSimilarityFunctionsTests() {
46+
protected VectorSimilarityFunctionsTests(int size) {
2447
logger.info(platformMsg());
48+
this.size = size;
2549
vectorSimilarityFunctions = NativeAccess.instance().getVectorSimilarityFunctions();
2650
}
2751

52+
public static void setup() {
53+
arena = Arena.ofConfined();
54+
}
55+
56+
public static void cleanup() {
57+
arena.close();
58+
}
59+
2860
public void testSupported() {
2961
supported();
3062
}
@@ -59,4 +91,9 @@ public static String platformMsg() {
5991
var osName = System.getProperty("os.name");
6092
return "JDK=" + jdkVersion + ", os=" + osName + ", arch=" + arch;
6193
}
94+
95+
// Support for passing on-heap arrays/segments to native
96+
protected static boolean supportsHeapSegments() {
97+
return Runtime.version().feature() >= 22;
98+
}
6299
}

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java renamed to libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,47 +15,33 @@
1515
import org.junit.AfterClass;
1616
import org.junit.BeforeClass;
1717

18-
import java.lang.foreign.Arena;
1918
import java.lang.foreign.MemorySegment;
20-
import java.util.stream.IntStream;
2119

2220
import static org.hamcrest.Matchers.containsString;
2321

24-
public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
22+
public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests {
2523

2624
// bounds of the range of values that can be seen by int7 scalar quantized vectors
2725
static final byte MIN_INT7_VALUE = 0;
2826
static final byte MAX_INT7_VALUE = 127;
2927

30-
static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
31-
static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
32-
33-
static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 };
34-
35-
final int size;
36-
37-
static Arena arena;
38-
39-
final double delta;
40-
41-
public JDKVectorLibraryTests(int size) {
42-
this.size = size;
43-
this.delta = 1e-5 * size; // scale the delta with the size
28+
public JDKVectorLibraryInt7uTests(int size) {
29+
super(size);
4430
}
4531

4632
@BeforeClass
47-
public static void setup() {
48-
arena = Arena.ofConfined();
33+
public static void beforeClass() {
34+
VectorSimilarityFunctionsTests.setup();
4935
}
5036

5137
@AfterClass
52-
public static void cleanup() {
53-
arena.close();
38+
public static void afterClass() {
39+
VectorSimilarityFunctionsTests.cleanup();
5440
}
5541

5642
@ParametersFactory
5743
public static Iterable<Object[]> parametersFactory() {
58-
return () -> IntStream.of(VECTOR_DIMS).boxed().map(i -> new Object[] { i }).iterator();
44+
return VectorSimilarityFunctionsTests.parametersFactory();
5945
}
6046

6147
public void testInt7BinaryVectors() {
@@ -79,7 +65,7 @@ public void testInt7BinaryVectors() {
7965
// dot product
8066
int expected = dotProductScalar(values[first], values[second]);
8167
assertEquals(expected, dotProduct7u(nativeSeg1, nativeSeg2, dims));
82-
if (testWithHeapSegments()) {
68+
if (supportsHeapSegments()) {
8369
var heapSeg1 = MemorySegment.ofArray(values[first]);
8470
var heapSeg2 = MemorySegment.ofArray(values[second]);
8571
assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims));
@@ -90,7 +76,7 @@ public void testInt7BinaryVectors() {
9076
// square distance
9177
expected = squareDistanceScalar(values[first], values[second]);
9278
assertEquals(expected, squareDistance7u(nativeSeg1, nativeSeg2, dims));
93-
if (testWithHeapSegments()) {
79+
if (supportsHeapSegments()) {
9480
var heapSeg1 = MemorySegment.ofArray(values[first]);
9581
var heapSeg2 = MemorySegment.ofArray(values[second]);
9682
assertEquals(expected, squareDistance7u(heapSeg1, heapSeg2, dims));
@@ -100,10 +86,6 @@ public void testInt7BinaryVectors() {
10086
}
10187
}
10288

103-
static boolean testWithHeapSegments() {
104-
return Runtime.version().feature() >= 22;
105-
}
106-
10789
public void testIllegalDims() {
10890
assumeTrue(notSupportedMsg(), supported());
10991
var segment = arena.allocate((long) size * 3);

0 commit comments

Comments
 (0)