Skip to content

Commit 2d769d4

Browse files
Update XLA (#1614)
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent d9364bb commit 2d769d4

26 files changed

+365
-329
lines changed

.github/workflows/ci.yml

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,12 @@ jobs:
1212
fail-fast: false
1313
matrix:
1414
working_directory: ["nx", "exla", "torchx"]
15-
elixir: ["1.15.4", "1.16.2"]
16-
otp: ["25.3"]
15+
elixir: ["1.15.8", "1.18.4"]
1716
include:
18-
- elixir: "1.16.2"
17+
- elixir: "1.15.8"
18+
otp: "25.3"
19+
- elixir: "1.18.4"
20+
otp: "27.3"
1921
lint: true
2022
defaults:
2123
run:
@@ -57,8 +59,9 @@ jobs:
5759
fail-fast: false
5860
matrix:
5961
working_directory: ["nx", "torchx"]
60-
elixir: ["1.16.2"]
61-
otp: ["25.2"]
62+
include:
63+
- elixir: "1.18.4"
64+
otp: "27.3"
6265
defaults:
6366
run:
6467
working-directory: ${{ matrix.working_directory }}

exla/Makefile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)
2525
# Note that XLA requires c++17, Fine as well
2626
CFLAGS += -fPIC -I$(ERTS_INCLUDE_DIR) -I$(FINE_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \
2727
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
28-
-std=c++17 -w -DLLVM_VERSION_STRING=
28+
-std=c++17 -w
2929

3030
ifdef DEBUG
3131
CFLAGS += -g
3232
else
3333
CFLAGS += -O3
3434
endif
3535

36-
NVCC := $(CXX)
36+
NVCC = $(CXX)
3737
NVCCFLAGS = $(CFLAGS)
3838
LDFLAGS += -L$(XLA_EXTENSION_LIB) -lxla_extension -shared -fvisibility=hidden
3939

@@ -48,8 +48,8 @@ $(info EXLA_CPU_ONLY is not set, checking for nvcc availability)
4848

4949
ifeq ($(NVCC_TEST),nvcc)
5050
$(info CUDA is available.)
51-
NVCC := nvcc
52-
NVCCFLAGS += -DCUDA_ENABLED
51+
NVCC = nvcc
52+
NVCCFLAGS = -Xcompiler "$(CFLAGS)" -DCUDA_ENABLED
5353
else
5454
$(info CUDA is not available.)
5555
endif
@@ -82,7 +82,7 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
8282
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
8383
fi
8484

85-
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/ipc.cc
85+
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/ipc.cc
8686
SOURCES += $(wildcard $(EXLA_DIR)/custom_calls/*.cc)
8787
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls/qr.h $(EXLA_DIR)/custom_calls/eigh.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/ipc.h
8888
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o

exla/c_src/exla/custom_calls.cc

Lines changed: 0 additions & 23 deletions
This file was deleted.

exla/c_src/exla/custom_calls/eigh.h

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
#pragma once
22

3-
#include "Eigen/Eigenvalues"
4-
53
#include <algorithm>
64
#include <iostream>
75
#include <numeric>
86
#include <vector>
97

8+
#include "Eigen/Eigenvalues"
9+
#include "xla/ffi/api/ffi.h"
10+
#include "xla/ffi/ffi_api.h"
11+
12+
namespace ffi = xla::ffi;
13+
1014
template <typename DataType>
1115
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out,
1216
DataType *eigenvectors_out,
@@ -55,51 +59,32 @@ void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out,
5559
m * n * sizeof(DataType));
5660
}
5761

// Removed by this commit: the legacy entry point that took raw custom-call
// pointer arrays. It read the operand from in[0], the rank of each shape from
// in[1], and the shapes themselves from in[2]..in[4]; results were written
// through out[0] (eigenvalues) and out[1] (eigenvectors).
58-
template <typename DataType>
59-
void eigh_cpu_custom_call(void *out[], const void *in[]) {
60-
DataType *operand = (DataType *)in[0];
61-
62-
uint64_t *dim_sizes = (uint64_t *)in[1];
63-
uint64_t num_operand_dims = dim_sizes[0];
64-
uint64_t num_eigenvalues_dims = dim_sizes[1];
65-
uint64_t num_eigenvectors_dims = dim_sizes[2];
66-
67-
uint64_t *operand_dims_ptr = (uint64_t *)in[2];
68-
std::vector<uint64_t> operand_dims(operand_dims_ptr,
69-
operand_dims_ptr + num_operand_dims);
70-
71-
uint64_t *eigenvalues_dims_ptr = (uint64_t *)in[3];
72-
std::vector<uint64_t> eigenvalues_dims(
73-
eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
74-
75-
uint64_t *eigenvectors_dims_ptr = (uint64_t *)in[4];
76-
std::vector<uint64_t> eigenvectors_dims(
77-
eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
// Added by this commit: the XLA FFI replacement. Shapes are now read from the
// buffers' own dimensions() instead of from extra shape operands.
//   DataType   - element type passed to single_matrix_eigh_cpu_custom_call.
//   BufferType - ffi buffer type of the operand and of both results.
// Returns ffi::Error::Success() unconditionally.
62+
template <typename DataType, typename BufferType>
63+
ffi::Error eigh_cpu_custom_call_impl(BufferType operand,
64+
ffi::Result<BufferType> eigenvalues,
65+
ffi::Result<BufferType> eigenvectors) {
66+
auto operand_dims = operand.dimensions();
67+
auto eigenvalues_dims = eigenvalues->dimensions();
68+
auto eigenvectors_dims = eigenvectors->dimensions();
7869

// m and n are the last two dimensions of the eigenvectors shape.
7970
uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
8071
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
8172

82-
auto leading_dimensions =
83-
std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
84-
// batch_items is the product of every operand dimension except the last two.
// NOTE(review): `operand_dims.end() - 2` presumes the operand has rank >= 2;
// no rank check is visible in this hunk - confirm the caller guarantees it.
8573
uint64_t batch_items = 1;
86-
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
87-
batch_items *= leading_dimensions[i];
74+
for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
75+
batch_items *= *it;
8876
}
8977

90-
DataType *eigenvalues = (DataType *)out[0];
91-
DataType *eigenvectors = (DataType *)out[1];
92-
// Element offsets between consecutive batch items in each buffer.
9378
uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
94-
uint64_t eigenvectors_stride =
95-
eigenvectors_dims[eigenvectors_dims.size() - 1] *
96-
eigenvectors_dims[eigenvectors_dims.size() - 2];
79+
uint64_t eigenvectors_stride = m * n;
9780
uint64_t inner_stride = m * n;
9881

// One decomposition per batch item, each pointer advanced by its own stride.
9982
for (uint64_t i = 0; i < batch_items; i++) {
10083
single_matrix_eigh_cpu_custom_call<DataType>(
101-
eigenvalues + i * eigenvalues_stride,
102-
eigenvectors + i * eigenvectors_stride, operand + i * inner_stride, m,
103-
n);
84+
eigenvalues->typed_data() + i * eigenvalues_stride,
85+
eigenvectors->typed_data() + i * eigenvectors_stride,
86+
operand.typed_data() + i * inner_stride, m, n);
10487
}
105-
}
88+
89+
return ffi::Error::Success();
90+
}
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
#include "eigh.h"
22

// Removed by this commit: legacy f32 entry point that forwarded the raw
// out/in pointer arrays to eigh_cpu_custom_call<float>.
3-
void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
4-
eigh_cpu_custom_call<float>(out, in);
// Added by this commit: f32 wrapper that forwards its F32 buffers to
// eigh_cpu_custom_call_impl<float, ffi::Buffer<ffi::F32>> and returns its
// ffi::Error.
3+
ffi::Error
4+
eigh_cpu_custom_call_f32_impl(ffi::Buffer<ffi::F32> operand,
5+
ffi::ResultBuffer<ffi::F32> eigenvalues,
6+
ffi::ResultBuffer<ffi::F32> eigenvectors) {
7+
return eigh_cpu_custom_call_impl<float, ffi::Buffer<ffi::F32>>(
8+
operand, eigenvalues, eigenvectors);
59
}
10+
// Handler symbol eigh_cpu_custom_call_f32 wrapping the _impl function above,
// bound to one F32 buffer argument and two F32 buffer results.
11+
XLA_FFI_DEFINE_HANDLER_SYMBOL(eigh_cpu_custom_call_f32,
12+
eigh_cpu_custom_call_f32_impl,
13+
ffi::Ffi::Bind()
14+
.Arg<ffi::Buffer<ffi::F32>>()
15+
.Ret<ffi::Buffer<ffi::F32>>()
16+
.Ret<ffi::Buffer<ffi::F32>>());
17+
// Registers the handler with the XLA FFI API under the name
// "eigh_cpu_custom_call_f32" with platform string "Host".
18+
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eigh_cpu_custom_call_f32",
19+
"Host", eigh_cpu_custom_call_f32);
Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
#include "eigh.h"
22

// Removed by this commit: legacy f64 entry point that forwarded the raw
// out/in pointer arrays to eigh_cpu_custom_call<double>.
3-
void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
4-
eigh_cpu_custom_call<double>(out, in);
// Added by this commit: f64 wrapper that forwards its F64 buffers to
// eigh_cpu_custom_call_impl<double, ffi::Buffer<ffi::F64>> and returns its
// ffi::Error.
3+
ffi::Error
4+
eigh_cpu_custom_call_f64_impl(ffi::Buffer<ffi::F64> operand,
5+
ffi::ResultBuffer<ffi::F64> eigenvalues,
6+
ffi::ResultBuffer<ffi::F64> eigenvectors) {
7+
return eigh_cpu_custom_call_impl<double, ffi::Buffer<ffi::F64>>(
8+
operand, eigenvalues, eigenvectors);
59
}
10+
// Handler symbol eigh_cpu_custom_call_f64 wrapping the _impl function above,
// bound to one F64 buffer argument and two F64 buffer results.
11+
XLA_FFI_DEFINE_HANDLER_SYMBOL(eigh_cpu_custom_call_f64,
12+
eigh_cpu_custom_call_f64_impl,
13+
ffi::Ffi::Bind()
14+
.Arg<ffi::Buffer<ffi::F64>>()
15+
.Ret<ffi::Buffer<ffi::F64>>()
16+
.Ret<ffi::Buffer<ffi::F64>>());
17+
// Registers the handler with the XLA FFI API under the name
// "eigh_cpu_custom_call_f64" with platform string "Host".
18+
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eigh_cpu_custom_call_f64",
19+
"Host", eigh_cpu_custom_call_f64);

exla/c_src/exla/custom_calls/lu.h

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
11
#pragma once
22

3-
#include "Eigen/LU";
3+
#include <algorithm>
4+
#include <iostream>
5+
#include <numeric>
6+
#include <vector>
7+
8+
#include "Eigen/LU"
9+
#include "xla/ffi/api/ffi.h"
10+
#include "xla/ffi/ffi_api.h"
11+
12+
namespace ffi = xla::ffi;
413

514
template <typename DataType>
6-
void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) {
7-
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
15+
void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out,
16+
DataType *u_out, DataType *in,
17+
uint64_t n) {
18+
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
19+
Eigen::RowMajor>
20+
RowMajorMatrix;
821

922
Eigen::Map<RowMajorMatrix> input(in, n, n);
1023
Eigen::PartialPivLU<RowMajorMatrix> lu = input.partialPivLu();
1124

1225
// Get the permutation matrix P and convert to indices
13-
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> P = lu.permutationP();
26+
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> P =
27+
lu.permutationP();
1428
for (uint64_t i = 0; i < n; i++) {
1529
for (uint64_t j = 0; j < n; j++) {
1630
p_out[i * n + j] = static_cast<uint8_t>(P.indices()[i] == j ? 1 : 0);
@@ -24,7 +38,6 @@ void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType
2438
// Copy L matrix
2539
for (uint64_t i = 0; i < n; i++) {
2640
for (uint64_t j = 0; j < n; j++) {
27-
2841
if (j < i) {
2942
l_out[i * n + j] = static_cast<DataType>(L(i, j));
3043
} else if (j == i) {
@@ -47,49 +60,28 @@ void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType
4760
}
4861
}
4962

// Removed by this commit: the legacy entry point that took raw custom-call
// pointer arrays. It read the operand from in[0], the rank of each shape from
// in[1], and the operand/P/L/U shapes from in[2]..in[5]; results were written
// through out[0] (P), out[1] (L) and out[2] (U).
50-
template <typename DataType>
51-
void lu_cpu_custom_call(void *out[], const void *in[]) {
52-
DataType *operand = (DataType *)in[0];
53-
54-
uint64_t *dim_sizes = (uint64_t *)in[1];
55-
uint64_t num_operand_dims = dim_sizes[0];
56-
uint64_t num_p_dims = dim_sizes[1];
57-
uint64_t num_l_dims = dim_sizes[2];
58-
uint64_t num_u_dims = dim_sizes[3];
59-
60-
uint64_t *operand_dims_ptr = (uint64_t *)in[2];
61-
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
62-
63-
uint64_t *p_dims_ptr = (uint64_t *)in[3];
64-
std::vector<uint64_t> p_dims(p_dims_ptr, p_dims_ptr + num_p_dims);
65-
66-
uint64_t *l_dims_ptr = (uint64_t *)in[4];
67-
std::vector<uint64_t> l_dims(l_dims_ptr, l_dims_ptr + num_l_dims);
68-
69-
uint64_t *u_dims_ptr = (uint64_t *)in[5];
70-
std::vector<uint64_t> u_dims(u_dims_ptr, u_dims_ptr + num_u_dims);
71-
// Added by this commit: the XLA FFI replacement. Shapes are now read from the
// buffers' own dimensions() instead of from extra shape operands.
//   DataType   - element type passed to single_matrix_lu_cpu_custom_call.
//   BufferType - ffi buffer type of the operand and of the L and U results;
//                P is always a U8 buffer.
// Returns ffi::Error::Success() unconditionally.
63+
template <typename DataType, typename BufferType>
64+
ffi::Error
65+
lu_cpu_custom_call_impl(BufferType operand, ffi::Result<ffi::Buffer<ffi::U8>> p,
66+
ffi::Result<BufferType> l, ffi::Result<BufferType> u) {
67+
auto operand_dims = operand.dimensions();
68+
auto l_dims = l->dimensions();
// n is the last dimension of L; each matrix is handled as n x n below.
7269
uint64_t n = l_dims[l_dims.size() - 1];
7370

74-
auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);
75-
// batch_items is the product of every operand dimension except the last two.
// NOTE(review): `operand_dims.end() - 2` presumes the operand has rank >= 2;
// no rank check is visible in this hunk - confirm the caller guarantees it.
7671
uint64_t batch_items = 1;
77-
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
78-
batch_items *= leading_dimensions[i];
72+
for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
73+
batch_items *= *it;
7974
}
8075

81-
uint8_t *p = (uint8_t *)out[0];
82-
DataType *l = (DataType *)out[1];
83-
DataType *u = (DataType *)out[2];
84-
// The same element stride (n * n) is used for P, L, U and the operand.
8576
uint64_t stride = n * n;
8677

// One factorization per batch item.
// NOTE(review): L, U and the operand are reinterpreted as DataType through
// untyped_data() with C-style casts - presumably because DataType (e.g.
// exla::float16) differs from the buffer's native element type; confirm.
8778
for (uint64_t i = 0; i < batch_items; i++) {
8879
single_matrix_lu_cpu_custom_call<DataType>(
89-
p + i * stride,
90-
l + i * stride,
91-
u + i * stride,
92-
operand + i * stride,
93-
n);
80+
p->typed_data() + i * stride,
81+
(DataType *)l->untyped_data() + i * stride,
82+
(DataType *)u->untyped_data() + i * stride,
83+
(DataType *)operand.untyped_data() + i * stride, n);
9484
}
95-
}
85+
86+
return ffi::Error::Success();
87+
}
Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1-
#include "lu.h"
21
#include "../exla_types.h"
2+
#include "lu.h"
33

// Removed by this commit: legacy bf16 entry point that forwarded the raw
// out/in pointer arrays to lu_cpu_custom_call<exla::bfloat16>.
4-
void lu_cpu_custom_call_bf16(void *out[], const void *in[]) {
5-
lu_cpu_custom_call<exla::bfloat16>(out, in);
// Added by this commit: bf16 wrapper that forwards its buffers to
// lu_cpu_custom_call_impl<exla::bfloat16, ffi::Buffer<ffi::BF16>> and returns
// its ffi::Error.
4+
ffi::Error lu_cpu_custom_call_bf16_impl(ffi::Buffer<ffi::BF16> operand,
5+
ffi::ResultBuffer<ffi::U8> p,
6+
ffi::ResultBuffer<ffi::BF16> l,
7+
ffi::ResultBuffer<ffi::BF16> u) {
8+
return lu_cpu_custom_call_impl<exla::bfloat16, ffi::Buffer<ffi::BF16>>(
9+
operand, p, l, u);
610
}
11+
// Handler symbol lu_cpu_custom_call_bf16 wrapping the _impl function above,
// bound to one BF16 buffer argument and three results: U8 (P), BF16 (L),
// BF16 (U).
12+
XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_bf16,
13+
lu_cpu_custom_call_bf16_impl,
14+
ffi::Ffi::Bind()
15+
.Arg<ffi::Buffer<ffi::BF16>>()
16+
.Ret<ffi::Buffer<ffi::U8>>()
17+
.Ret<ffi::Buffer<ffi::BF16>>()
18+
.Ret<ffi::Buffer<ffi::BF16>>());
19+
// Registers the handler with the XLA FFI API under the name
// "lu_cpu_custom_call_bf16" with platform string "Host".
20+
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_bf16", "Host",
21+
lu_cpu_custom_call_bf16);
Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
1-
#include "lu.h"
21
#include "../exla_types.h"
2+
#include "lu.h"
33

// Removed by this commit: legacy f16 entry point that forwarded the raw
// out/in pointer arrays to lu_cpu_custom_call<exla::float16>.
4-
void lu_cpu_custom_call_f16(void *out[], const void *in[]) {
5-
lu_cpu_custom_call<exla::float16>(out, in);
// Added by this commit: f16 wrapper that forwards its buffers to
// lu_cpu_custom_call_impl<exla::float16, ffi::Buffer<ffi::F16>> and returns
// its ffi::Error.
4+
ffi::Error lu_cpu_custom_call_f16_impl(ffi::Buffer<ffi::F16> operand,
5+
ffi::ResultBuffer<ffi::U8> p,
6+
ffi::ResultBuffer<ffi::F16> l,
7+
ffi::ResultBuffer<ffi::F16> u) {
8+
return lu_cpu_custom_call_impl<exla::float16, ffi::Buffer<ffi::F16>>(operand,
9+
p, l, u);
610
}
11+
// Handler symbol lu_cpu_custom_call_f16 wrapping the _impl function above,
// bound to one F16 buffer argument and three results: U8 (P), F16 (L),
// F16 (U).
12+
XLA_FFI_DEFINE_HANDLER_SYMBOL(lu_cpu_custom_call_f16,
13+
lu_cpu_custom_call_f16_impl,
14+
ffi::Ffi::Bind()
15+
.Arg<ffi::Buffer<ffi::F16>>()
16+
.Ret<ffi::Buffer<ffi::U8>>()
17+
.Ret<ffi::Buffer<ffi::F16>>()
18+
.Ret<ffi::Buffer<ffi::F16>>());
19+
// Registers the handler with the XLA FFI API under the name
// "lu_cpu_custom_call_f16" with platform string "Host".
20+
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "lu_cpu_custom_call_f16", "Host",
21+
lu_cpu_custom_call_f16);

0 commit comments

Comments
 (0)