15
15
import org .junit .AfterClass ;
16
16
import org .junit .BeforeClass ;
17
17
18
- import java .lang .foreign .Arena ;
19
18
import java .lang .foreign .MemorySegment ;
20
- import java .util .stream .IntStream ;
21
19
22
20
import static org .hamcrest .Matchers .containsString ;
23
21
24
- public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
22
+ public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests {
25
23
26
24
// bounds of the range of values that can be seen by int7 scalar quantized vectors
27
25
static final byte MIN_INT7_VALUE = 0 ;
28
26
static final byte MAX_INT7_VALUE = 127 ;
29
27
30
- static final Class <IllegalArgumentException > IAE = IllegalArgumentException .class ;
31
- static final Class <IndexOutOfBoundsException > IOOBE = IndexOutOfBoundsException .class ;
32
-
33
- static final int [] VECTOR_DIMS = { 1 , 4 , 6 , 8 , 13 , 16 , 25 , 31 , 32 , 33 , 64 , 100 , 128 , 207 , 256 , 300 , 512 , 702 , 1023 , 1024 , 1025 };
34
-
35
- final int size ;
36
-
37
- static Arena arena ;
38
-
39
- final double delta ;
40
-
41
- public JDKVectorLibraryTests (int size ) {
42
- this .size = size ;
43
- this .delta = 1e-5 * size ; // scale the delta with the size
28
+ public JDKVectorLibraryInt7uTests (int size ) {
29
+ super (size );
44
30
}
45
31
46
32
@ BeforeClass
47
- public static void setup () {
48
- arena = Arena . ofConfined ();
33
+ public static void beforeClass () {
34
+ VectorSimilarityFunctionsTests . setup ();
49
35
}
50
36
51
37
@ AfterClass
52
- public static void cleanup () {
53
- arena . close ();
38
+ public static void afterClass () {
39
+ VectorSimilarityFunctionsTests . cleanup ();
54
40
}
55
41
56
42
@ ParametersFactory
57
43
public static Iterable <Object []> parametersFactory () {
58
- return () -> IntStream . of ( VECTOR_DIMS ). boxed (). map ( i -> new Object [] { i }). iterator ();
44
+ return VectorSimilarityFunctionsTests . parametersFactory ();
59
45
}
60
46
61
47
public void testInt7BinaryVectors () {
@@ -79,7 +65,7 @@ public void testInt7BinaryVectors() {
79
65
// dot product
80
66
int expected = dotProductScalar (values [first ], values [second ]);
81
67
assertEquals (expected , dotProduct7u (nativeSeg1 , nativeSeg2 , dims ));
82
- if (testWithHeapSegments ()) {
68
+ if (supportsHeapSegments ()) {
83
69
var heapSeg1 = MemorySegment .ofArray (values [first ]);
84
70
var heapSeg2 = MemorySegment .ofArray (values [second ]);
85
71
assertEquals (expected , dotProduct7u (heapSeg1 , heapSeg2 , dims ));
@@ -90,7 +76,7 @@ public void testInt7BinaryVectors() {
90
76
// square distance
91
77
expected = squareDistanceScalar (values [first ], values [second ]);
92
78
assertEquals (expected , squareDistance7u (nativeSeg1 , nativeSeg2 , dims ));
93
- if (testWithHeapSegments ()) {
79
+ if (supportsHeapSegments ()) {
94
80
var heapSeg1 = MemorySegment .ofArray (values [first ]);
95
81
var heapSeg2 = MemorySegment .ofArray (values [second ]);
96
82
assertEquals (expected , squareDistance7u (heapSeg1 , heapSeg2 , dims ));
@@ -100,10 +86,6 @@ public void testInt7BinaryVectors() {
100
86
}
101
87
}
102
88
103
- static boolean testWithHeapSegments () {
104
- return Runtime .version ().feature () >= 22 ;
105
- }
106
-
107
89
public void testIllegalDims () {
108
90
assumeTrue (notSupportedMsg (), supported ());
109
91
var segment = arena .allocate ((long ) size * 3 );
0 commit comments