25
25
package com .oracle .svm .hosted ;
26
26
27
27
import java .lang .reflect .AccessFlag ;
28
+ import java .lang .reflect .Array ;
28
29
import java .lang .reflect .Field ;
29
30
import java .lang .reflect .InvocationTargetException ;
30
31
import java .lang .reflect .Method ;
31
32
import java .util .ArrayList ;
33
+ import java .util .Arrays ;
32
34
import java .util .Locale ;
33
35
import java .util .function .Function ;
34
36
import java .util .function .IntFunction ;
62
64
public class VectorAPIFeature implements InternalFeature {
63
65
64
66
public static final String VECTOR_API_PACKAGE_NAME = "jdk.incubator.vector" ;
67
+ public static final Class <?> PAYLOAD_CLASS = ReflectionUtil .lookupClass ("jdk.internal.vm.vector.VectorSupport$VectorPayload" );
65
68
66
69
static final Unsafe UNSAFE = Unsafe .getUnsafe ();
67
70
@@ -156,10 +159,15 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
156
159
157
160
String maxVectorName = VECTOR_API_PACKAGE_NAME + "." + elementName + "MaxVector" ;
158
161
Class <?> maxVectorClass = ReflectionUtil .lookupClass (maxVectorName );
162
+ int laneCount = VectorAPISupport .singleton ().getMaxLaneCount (vectorElement );
159
163
access .registerFieldValueTransformer (ReflectionUtil .lookupField (maxVectorClass , "VSIZE" ),
160
164
(receiver , originalValue ) -> maxVectorBits );
161
165
access .registerFieldValueTransformer (ReflectionUtil .lookupField (maxVectorClass , "VLENGTH" ),
162
- (receiver , originalValue ) -> VectorAPISupport .singleton ().getMaxLaneCount (vectorElement ));
166
+ (receiver , originalValue ) -> laneCount );
167
+ access .registerFieldValueTransformer (ReflectionUtil .lookupField (maxVectorClass , "ZERO" ),
168
+ (receiver , originalValue ) -> makeZeroVector (maxVectorClass , vectorElement , laneCount ));
169
+ access .registerFieldValueTransformer (ReflectionUtil .lookupField (maxVectorClass , "IOTA" ),
170
+ (receiver , originalValue ) -> makeIotaVector (maxVectorClass , vectorElement , laneCount ));
163
171
}
164
172
165
173
Class <?> speciesClass = ReflectionUtil .lookupClass (VECTOR_API_PACKAGE_NAME + ".AbstractSpecies" );
@@ -173,7 +181,8 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
173
181
* intrinsify operations, we may need to access information about a type before the analysis
174
182
* has seen it.
175
183
*/
176
- for (String elementName : vectorElementNames ) {
184
+ for (Class <?> vectorElement : vectorElements ) {
185
+ String elementName = vectorElement .getName ().substring (0 , 1 ).toUpperCase (Locale .ROOT ) + vectorElement .getName ().substring (1 );
177
186
for (String size : vectorSizes ) {
178
187
String baseName = elementName + size ;
179
188
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + baseName + "Vector" ;
@@ -183,6 +192,16 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
183
192
Class <?> maskClass = ReflectionUtil .lookupClass (vectorClassName + "$" + baseName + "Mask" );
184
193
UNSAFE .ensureClassInitialized (maskClass );
185
194
access .registerAsUsed (maskClass );
195
+ if (size .equals ("Max" )) {
196
+ int laneCount = VectorAPISupport .singleton ().getMaxLaneCount (vectorElement );
197
+ Class <?> shuffleElement = (vectorElement == float .class ? int .class : vectorElement == double .class ? long .class : vectorElement );
198
+ access .registerFieldValueTransformer (ReflectionUtil .lookupField (shuffleClass , "IOTA" ),
199
+ (receiver , originalValue ) -> makeIotaVector (shuffleClass , shuffleElement , laneCount ));
200
+ access .registerFieldValueTransformer (ReflectionUtil .lookupField (maskClass , "TRUE_MASK" ),
201
+ (receiver , originalValue ) -> makeNewInstanceWithBooleanPayload (maskClass , laneCount , true ));
202
+ access .registerFieldValueTransformer (ReflectionUtil .lookupField (maskClass , "FALSE_MASK" ),
203
+ (receiver , originalValue ) -> makeNewInstanceWithBooleanPayload (maskClass , laneCount , false ));
204
+ }
186
205
}
187
206
}
188
207
@@ -369,6 +388,42 @@ public static void makeConversionOperations(Class<?> conversionImplClass, Warmup
369
388
}
370
389
}
371
390
391
+ private static Object makeZeroVector (Class <?> vectorClass , Class <?> vectorElement , int laneCount ) {
392
+ Object zeroPayload = Array .newInstance (vectorElement , laneCount );
393
+ return ReflectionUtil .newInstance (ReflectionUtil .lookupConstructor (vectorClass , zeroPayload .getClass ()), zeroPayload );
394
+ }
395
+
396
+ private static Object makeNewInstanceWithBooleanPayload (Class <?> maskClass , int laneCount , boolean fillValue ) {
397
+ /*
398
+ * The constructors for Mask classes allocate new arrays based on the species length, which
399
+ * we also substitute but whose substituted value will not be used yet. So instead of just
400
+ * calling a constructor with a boolean array, we brute force this: We allocate a new
401
+ * instance which may have a payload with an incorrect length, then override its payload
402
+ * field.
403
+ */
404
+ Object newInstance = ReflectionUtil .newInstance (ReflectionUtil .lookupConstructor (maskClass , boolean .class ), true );
405
+ boolean [] payload = new boolean [laneCount ];
406
+ Arrays .fill (payload , fillValue );
407
+ ReflectionUtil .writeField (PAYLOAD_CLASS , "payload" , newInstance , payload );
408
+ return newInstance ;
409
+ }
410
+
411
+ private static Object makeIotaVector (Class <?> vectorClass , Class <?> vectorElement , int laneCount ) {
412
+ Object iotaPayload = Array .newInstance (vectorElement , laneCount );
413
+ for (int i = 0 ; i < laneCount ; i ++) {
414
+ // adapted from AbstractSpecies.iotaArray
415
+ if ((byte ) i == i ) {
416
+ Array .setByte (iotaPayload , i , (byte ) i );
417
+ } else if ((short ) i == i ) {
418
+ Array .setShort (iotaPayload , i , (short ) i );
419
+ } else {
420
+ Array .setInt (iotaPayload , i , i );
421
+ }
422
+ VMError .guarantee (Array .getDouble (iotaPayload , i ) == i , "wrong initialization of iota array: %s at %s" , Array .getDouble (iotaPayload , i ), i );
423
+ }
424
+ return ReflectionUtil .newInstance (ReflectionUtil .lookupConstructor (vectorClass , iotaPayload .getClass ()), iotaPayload );
425
+ }
426
+
372
427
@ Override
373
428
public void registerInvocationPlugins (Providers providers , GraphBuilderConfiguration .Plugins plugins , ParsingReason reason ) {
374
429
if (VectorAPIIntrinsics .intrinsificationSupported (HostedOptionValues .singleton ())) {
0 commit comments