Skip to content

Commit 9ab5967

Browse files
committed
[GR-66581] VectorAPIFeature: register field value transformers for cached vector instances
PullRequest: graal/21279
2 parents b6f8562 + 71c6f4f commit 9ab5967

File tree

1 file changed

+57
-2
lines changed

1 file changed

+57
-2
lines changed

substratevm/src/com.oracle.svm.hosted/src/com/oracle/svm/hosted/VectorAPIFeature.java

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
package com.oracle.svm.hosted;
2626

2727
import java.lang.reflect.AccessFlag;
28+
import java.lang.reflect.Array;
2829
import java.lang.reflect.Field;
2930
import java.lang.reflect.InvocationTargetException;
3031
import java.lang.reflect.Method;
3132
import java.util.ArrayList;
33+
import java.util.Arrays;
3234
import java.util.Locale;
3335
import java.util.function.Function;
3436
import java.util.function.IntFunction;
@@ -62,6 +64,7 @@
6264
public class VectorAPIFeature implements InternalFeature {
6365

6466
public static final String VECTOR_API_PACKAGE_NAME = "jdk.incubator.vector";
67+
public static final Class<?> PAYLOAD_CLASS = ReflectionUtil.lookupClass("jdk.internal.vm.vector.VectorSupport$VectorPayload");
6568

6669
static final Unsafe UNSAFE = Unsafe.getUnsafe();
6770

@@ -156,10 +159,15 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
156159

157160
String maxVectorName = VECTOR_API_PACKAGE_NAME + "." + elementName + "MaxVector";
158161
Class<?> maxVectorClass = ReflectionUtil.lookupClass(maxVectorName);
162+
int laneCount = VectorAPISupport.singleton().getMaxLaneCount(vectorElement);
159163
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "VSIZE"),
160164
(receiver, originalValue) -> maxVectorBits);
161165
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "VLENGTH"),
162-
(receiver, originalValue) -> VectorAPISupport.singleton().getMaxLaneCount(vectorElement));
166+
(receiver, originalValue) -> laneCount);
167+
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "ZERO"),
168+
(receiver, originalValue) -> makeZeroVector(maxVectorClass, vectorElement, laneCount));
169+
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maxVectorClass, "IOTA"),
170+
(receiver, originalValue) -> makeIotaVector(maxVectorClass, vectorElement, laneCount));
163171
}
164172

165173
Class<?> speciesClass = ReflectionUtil.lookupClass(VECTOR_API_PACKAGE_NAME + ".AbstractSpecies");
@@ -173,7 +181,8 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
173181
* intrinsify operations, we may need to access information about a type before the analysis
174182
* has seen it.
175183
*/
176-
for (String elementName : vectorElementNames) {
184+
for (Class<?> vectorElement : vectorElements) {
185+
String elementName = vectorElement.getName().substring(0, 1).toUpperCase(Locale.ROOT) + vectorElement.getName().substring(1);
177186
for (String size : vectorSizes) {
178187
String baseName = elementName + size;
179188
String vectorClassName = VECTOR_API_PACKAGE_NAME + "." + baseName + "Vector";
@@ -183,6 +192,16 @@ public void beforeAnalysis(BeforeAnalysisAccess access) {
183192
Class<?> maskClass = ReflectionUtil.lookupClass(vectorClassName + "$" + baseName + "Mask");
184193
UNSAFE.ensureClassInitialized(maskClass);
185194
access.registerAsUsed(maskClass);
195+
if (size.equals("Max")) {
196+
int laneCount = VectorAPISupport.singleton().getMaxLaneCount(vectorElement);
197+
Class<?> shuffleElement = (vectorElement == float.class ? int.class : vectorElement == double.class ? long.class : vectorElement);
198+
access.registerFieldValueTransformer(ReflectionUtil.lookupField(shuffleClass, "IOTA"),
199+
(receiver, originalValue) -> makeIotaVector(shuffleClass, shuffleElement, laneCount));
200+
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maskClass, "TRUE_MASK"),
201+
(receiver, originalValue) -> makeNewInstanceWithBooleanPayload(maskClass, laneCount, true));
202+
access.registerFieldValueTransformer(ReflectionUtil.lookupField(maskClass, "FALSE_MASK"),
203+
(receiver, originalValue) -> makeNewInstanceWithBooleanPayload(maskClass, laneCount, false));
204+
}
186205
}
187206
}
188207

@@ -369,6 +388,42 @@ public static void makeConversionOperations(Class<?> conversionImplClass, Warmup
369388
}
370389
}
371390

391+
private static Object makeZeroVector(Class<?> vectorClass, Class<?> vectorElement, int laneCount) {
392+
Object zeroPayload = Array.newInstance(vectorElement, laneCount);
393+
return ReflectionUtil.newInstance(ReflectionUtil.lookupConstructor(vectorClass, zeroPayload.getClass()), zeroPayload);
394+
}
395+
396+
private static Object makeNewInstanceWithBooleanPayload(Class<?> maskClass, int laneCount, boolean fillValue) {
397+
/*
398+
* The constructors for Mask classes allocate new arrays based on the species length, which
399+
* we also substitute but whose substituted value will not be used yet. So instead of just
400+
* calling a constructor with a boolean array, we brute force this: We allocate a new
401+
* instance which may have a payload with an incorrect length, then override its payload
402+
* field.
403+
*/
404+
Object newInstance = ReflectionUtil.newInstance(ReflectionUtil.lookupConstructor(maskClass, boolean.class), true);
405+
boolean[] payload = new boolean[laneCount];
406+
Arrays.fill(payload, fillValue);
407+
ReflectionUtil.writeField(PAYLOAD_CLASS, "payload", newInstance, payload);
408+
return newInstance;
409+
}
410+
411+
private static Object makeIotaVector(Class<?> vectorClass, Class<?> vectorElement, int laneCount) {
412+
Object iotaPayload = Array.newInstance(vectorElement, laneCount);
413+
for (int i = 0; i < laneCount; i++) {
414+
// adapted from AbstractSpecies.iotaArray
415+
if ((byte) i == i) {
416+
Array.setByte(iotaPayload, i, (byte) i);
417+
} else if ((short) i == i) {
418+
Array.setShort(iotaPayload, i, (short) i);
419+
} else {
420+
Array.setInt(iotaPayload, i, i);
421+
}
422+
VMError.guarantee(Array.getDouble(iotaPayload, i) == i, "wrong initialization of iota array: %s at %s", Array.getDouble(iotaPayload, i), i);
423+
}
424+
return ReflectionUtil.newInstance(ReflectionUtil.lookupConstructor(vectorClass, iotaPayload.getClass()), iotaPayload);
425+
}
426+
372427
@Override
373428
public void registerInvocationPlugins(Providers providers, GraphBuilderConfiguration.Plugins plugins, ParsingReason reason) {
374429
if (VectorAPIIntrinsics.intrinsificationSupported(HostedOptionValues.singleton())) {

0 commit comments

Comments
 (0)