Skip to content

libvec: unroll pragma and push stride down #107460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/native/libraries/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ configurations {
}

var zstdVersion = "1.5.5"
var vecVersion = "1.0.1"
var vecVersion = "1.0.2"

repositories {
exclusiveContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import java.lang.invoke.MethodType;

import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;
import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle;

Expand All @@ -29,24 +28,9 @@ public final class JdkVectorLibrary implements VectorLibrary {

public JdkVectorLibrary() {}

static final MethodHandle dot8stride$mh = downcallHandle("dot8s_stride", FunctionDescriptor.of(JAVA_INT));
static final MethodHandle sqr8stride$mh = downcallHandle("sqr8s_stride", FunctionDescriptor.of(JAVA_INT));

static final MethodHandle dot8s$mh = downcallHandle("dot8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));
static final MethodHandle sqr8s$mh = downcallHandle("sqr8s", FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT));

// Stride of the native implementation - consumes this number of bytes per loop invocation.
// There must be at least this number of bytes/elements available when going native
static final int DOT_STRIDE = 32;
static final int SQR_STRIDE = 16;

static {
assert DOT_STRIDE > 0 && (DOT_STRIDE & (DOT_STRIDE - 1)) == 0 : "Not a power of two";
assert dot8Stride() == DOT_STRIDE : dot8Stride() + " != " + DOT_STRIDE;
assert SQR_STRIDE > 0 && (SQR_STRIDE & (SQR_STRIDE - 1)) == 0 : "Not a power of two";
assert sqr8Stride() == SQR_STRIDE : sqr8Stride() + " != " + SQR_STRIDE;
}

/**
* Computes the dot product of given byte vectors.
* @param a address of the first vector
Expand All @@ -61,19 +45,7 @@ static int dotProduct(MemorySegment a, MemorySegment b, int length) {
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= DOT_STRIDE) {
i += length & ~(DOT_STRIDE - 1);
res = dot8s(a, b, i);
}

// tail
for (; i < length; i++) {
res += a.get(JAVA_BYTE, i) * b.get(JAVA_BYTE, i);
}
assert i == length;
return res;
return dot8s(a, b, length);
}

/**
Expand All @@ -90,36 +62,7 @@ static int squareDistance(MemorySegment a, MemorySegment b, int length) {
if (length > a.byteSize()) {
throw new IllegalArgumentException("length: " + length + ", greater than vector dimensions: " + a.byteSize());
}
int i = 0;
int res = 0;
if (length >= SQR_STRIDE) {
i += length & ~(SQR_STRIDE - 1);
res = sqr8s(a, b, i);
}

// tail
for (; i < length; i++) {
int dist = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
res += dist * dist;
}
assert i == length;
return res;
}

private static int dot8Stride() {
try {
return (int) dot8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
}

private static int sqr8Stride() {
try {
return (int) sqr8stride$mh.invokeExact();
} catch (Throwable t) {
throw new AssertionError(t);
}
return sqr8s(a, b, length);
}

private static int dot8s(MemorySegment a, MemorySegment b, int length) {
Expand Down
8 changes: 7 additions & 1 deletion libs/vec/native/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@ apply plugin: 'c'

var os = org.gradle.internal.os.OperatingSystem.current()

// To update this library run publish_vec_binaries.sh
// To update this library run publish_vec_binaries.sh ( or ./gradlew vecSharedLibrary )
// Or
// For local development, build the docker image with:
// docker build --platform linux/arm64 --progress=plain .
// Grab the image id from the console output, then, e.g.
// docker run 9c9f36564c148b275aeecc42749e7b4580ded79dcf51ff6ccc008c8861e7a979 > build/libs/vec/shared/libvec.so
//
// To run tests and benchmarks on a locally built libvec,
// 1. Temporarily comment out the download in libs/native/library/build.gradle
// libs "org.elasticsearch:vec:${vecVersion}@zip"
// 2. Copy your locally built libvec binary, e.g.
// cp libs/vec/native/build/libs/vec/shared/libvec.dylib libs/native/libraries/build/platform/darwin-aarch64/libvec.dylib
//
// Look at the disassemble:
// objdump --disassemble-symbols=_dot8s build/libs/vec/shared/libvec.dylib
// Note: symbol decoration may differ on Linux, i.e. the leading underscore is not present
Expand Down
2 changes: 1 addition & 1 deletion libs/vec/native/publish_vec_binaries.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
exit 1;
fi

VERSION="1.0.1"
VERSION="1.0.2"
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
TEMP=$(mktemp -d)

Expand Down
41 changes: 31 additions & 10 deletions libs/vec/native/src/vec/c/vec.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,7 @@
#define SQR8S_STRIDE_BYTES_LEN 16
#endif

EXPORT int dot8s_stride() {
return DOT8_STRIDE_BYTES_LEN;
}

EXPORT int sqr8s_stride() {
return SQR8S_STRIDE_BYTES_LEN;
}

EXPORT int32_t dot8s(int8_t* a, int8_t* b, size_t dims) {
int32_t dot8s_inner(int8_t* a, int8_t* b, size_t dims) {
// We have contention in the instruction pipeline on the accumulation
// registers if we use too few.
int32x4_t acc1 = vdupq_n_s32(0);
Expand All @@ -35,6 +27,7 @@ EXPORT int32_t dot8s(int8_t* a, int8_t* b, size_t dims) {
int32x4_t acc4 = vdupq_n_s32(0);

// Some unrolling gives around 50% performance improvement.
#pragma clang loop unroll_count(2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps it might be worth tweaking the comment a bit, i.e. accumulating into multiple registers gives around 50%, and unroll directive gives around 5%. I think otherwise it is ambiguous.

for (int i = 0; i < dims; i += DOT8_STRIDE_BYTES_LEN) {
// Read into 16 x 8 bit vectors.
int8x16_t va1 = vld1q_s8(a + i);
Expand All @@ -60,12 +53,26 @@ EXPORT int32_t dot8s(int8_t* a, int8_t* b, size_t dims) {
return vaddvq_s32(vaddq_s32(acc5, acc6));
}

EXPORT int32_t sqr8s(int8_t *a, int8_t *b, size_t dims) {
EXPORT int32_t dot8s(int8_t* a, int8_t* b, size_t dims) {
int32_t res = 0;
int i = 0;
if (dims > DOT8_STRIDE_BYTES_LEN) {
i += dims & ~(DOT8_STRIDE_BYTES_LEN - 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works if DOT8_STRIDE_BYTES_LEN is a power of 2. Of course it will be. Perhaps it is worth enforcing this with a static_assert somewhere so people don't accidentally break it though, i.e. static_assert((1031 & ~(DOT8_STRIDE_BYTES_LEN - 1)) == (1031 - 1031 % DOT8_STRIDE_BYTES_LEN), "Invalid DOT8_STRIDE_BYTES_LEN must be a power of 2");. Note this can be anywhere in the source file, so you don't need to put in any actual function definition (although it shouldn't actually generate any code).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++
I think you can even make this a compile time error (not sure if you need to switch to c++ for that though)

res = dot8s_inner(a, b, i);
}
for (; i < dims; i++) {
res += a[i] * b[i];
}
return res;
}

int32_t sqr8s_inner(int8_t *a, int8_t *b, size_t dims) {
int32x4_t acc1 = vdupq_n_s32(0);
int32x4_t acc2 = vdupq_n_s32(0);
int32x4_t acc3 = vdupq_n_s32(0);
int32x4_t acc4 = vdupq_n_s32(0);

#pragma clang loop unroll_count(2)
for (int i = 0; i < dims; i += SQR8S_STRIDE_BYTES_LEN) {
int8x16_t va1 = vld1q_s8(a + i);
int8x16_t vb1 = vld1q_s8(b + i);
Expand All @@ -84,3 +91,17 @@ EXPORT int32_t sqr8s(int8_t *a, int8_t *b, size_t dims) {
int32x4_t acc6 = vaddq_s32(acc3, acc4);
return vaddvq_s32(vaddq_s32(acc5, acc6));
}

EXPORT int32_t sqr8s(int8_t* a, int8_t* b, size_t dims) {
int32_t res = 0;
int i = 0;
if (i > SQR8S_STRIDE_BYTES_LEN) {
i += dims & ~(SQR8S_STRIDE_BYTES_LEN - 1);
res = sqr8s_inner(a, b, i);
}
for (; i < dims; i++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can try and unroll this loop too?

int32_t dist = a[i] - b[i];
res += dist * dist;
}
return res;
}
4 changes: 0 additions & 4 deletions libs/vec/native/src/vec/headers/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

#define EXPORT __attribute__((externally_visible,visibility("default")))

EXPORT int dot8s_stride();

EXPORT int sqr8s_stride();

EXPORT int32_t dot8s(int8_t* a, int8_t* b, size_t dims);

EXPORT int32_t sqr8s(int8_t *a, int8_t *b, size_t length);