
Commit fdc0019

Basic implementation for avx (#69)
* add baseline implementation
* support onednn
* kernel timing
* performance
* int8 onednn
* utils
* minor
* fix
* minor
* avx imp for int8 gemm
* add file
* fix
* cuda compiler flags
* compilation for int8
* minor
* minor
* minor
* 2x2 unroll
* omp imp
* unroll 32 elements
* min/max params
* bias support
* minor
* minor
* fix
* bf32 fp32 ops
* bmm
* fix
* rounding
* fix
1 parent d412e11 commit fdc0019

File tree

11 files changed, +2168 -0 lines changed

@@ -0,0 +1,96 @@
# Check operating system
OS := $(shell uname)

# OneDNN availability
ONEDNN_AVAILABLE =
ifeq ($(OS), Darwin) # macOS
    $(info Detected macOS)
    ONEDNN_AVAILABLE := $(shell otool -L /usr/local/lib/libdnnl* 2> /dev/null)
else ifeq ($(OS), Linux) # Ubuntu or other Linux distributions
    $(info Detected Linux)
    ONEDNN_AVAILABLE_CHK := $(shell pkg-config --exists dnnl; echo $$?)
    ifeq ($(ONEDNN_AVAILABLE_CHK), 0)
        ONEDNN_AVAILABLE := $(shell pkg-config --exists onednn 2> /dev/null) # TODO: check this in Linux env
    endif
else
    $(error Unsupported operating system)
endif

# Check if CUDA is available
CUDA_AVAILABLE := $(shell command -v /usr/local/cuda/bin/nvcc 2> /dev/null)

CC_FLAGS = -O3 -std=c++11 #-g
#CC_FLAGS = -O3 -std=c++11 -Xclang -fopenmp -g
# Compiler and flags
ifdef CUDA_AVAILABLE
    CC = /usr/local/cuda/bin/nvcc
    CC_FLAGS += -DCUDA_ENABLE
    $(info CUDA is available)
else
    CC = g++
    CC_FLAGS += -mavx2 -mfma
endif
ifdef ONEDNN_AVAILABLE
    CC_FLAGS += -DONEDNN_ENABLE
    $(info ONEDNN is available)
endif

# Include directories
# INCLUDE_DIRS = -I./ -I/usr/local/opt/libomp/include
INCLUDE_DIRS = -I./

# Library directories
LIBRARY_DIRS = -L/usr/local/cuda/lib64

# Library flags
LDFLAGS =
ifdef ONEDNN_AVAILABLE
    LDFLAGS += -ldnnl
endif

# TODO: openmp flag
OMP_FLAGS = -L/usr/local/opt/libomp/lib/ -lomp
# LDFLAGS += $(OMP_FLAGS)

# Files
TARGET = benchmark_run
CUDA_SRCS = lib/matmul.cu
CPP_SRCS = benchmark/main.cc lib/matmul_imp.cc lib/utils.cc lib/matmul_int8.cc lib/matmul_avx_int8.cc
ONEDNN_SRCS = lib/matmul_onednn.cc

# Objects
OBJS = $(CPP_SRCS:.cc=.o)
INT8_OBJS = $(INT8_CPP_SRCS:.cc=.o)
ifdef CUDA_AVAILABLE
OBJS += $(CUDA_SRCS:.cu=.o)
endif
ifdef ONEDNN_AVAILABLE
OBJS += $(ONEDNN_SRCS:.cc=.o)
INT8_OBJS += $(ONEDNN_SRCS:.cc=.o)
endif


# $(info ONEDNN_AVAILABLE: $(ONEDNN_AVAILABLE))
$(info CC_FLAGS: $(CC_FLAGS))


# Targets
all: $(TARGET)

$(TARGET): $(OBJS)
	$(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -o $(TARGET) $(OBJS)

%.o: %.cu
	$(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -c $< -o $@

ifdef CUDA_AVAILABLE
%.o: %.cc
	$(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -x cu -c $< -o $@
else
%.o: %.cc
	$(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -c $< -o $@
#	$(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -c $< -o $@ $(OMP_FLAGS)
endif

clean:
	rm -f $(TARGET) $(OBJS)
@@ -0,0 +1,3 @@
# Build oneDNN (enable OpenMP on macOS)

cmake .. -DOpenMP_C_FLAGS="-Xclang -fopenmp -I/usr/local/opt/libomp/include" -DOpenMP_C_LIB_NAMES="libomp" -DDNNL_CPU_RUNTIME=OMP -DOpenMP_CXX_FLAGS="-Xclang -fopenmp -I/usr/local/opt/libomp/include" -DOpenMP_CXX_LIB_NAMES="libomp" -DOpenMP_libomp_LIBRARY=/usr/local/opt/libomp/lib/libomp.dylib -DCMAKE_SHARED_LINKER_FLAGS="-L/usr/local/opt/libomp/lib/ -lomp -Wl,-rpath,/usr/local/opt/libomp/lib/"
@@ -0,0 +1,241 @@
#include <math.h>
#include <stdio.h>

#include <cstdlib>
#include <iostream>

#include "lib/matmul.h"

#define BLK_SIZE 16
#define MAX_PRECISION_ERROR 0.01

#define M 1024
#define N 1024
#define K 1024
#define A_ROW M
#define A_COLUMN K
#define B_ROW K
#define B_COLUMN N
#define C_ROW M
#define C_COLUMN N
#define NUM_THREAD 16

float MAT_A[A_ROW * A_COLUMN];
float MAT_B[B_ROW * B_COLUMN];
float transpose_B[B_ROW * B_COLUMN];
float native_C[C_ROW * C_COLUMN];
float output_C[C_ROW * C_COLUMN];

int8_t MAT_A_s8[A_ROW * A_COLUMN];
int8_t MAT_B_s8[B_ROW * B_COLUMN];
int32_t bias_s32[C_COLUMN];
int8_t transpose_B_s8[B_ROW * B_COLUMN];
int8_t native_C_s8[C_ROW * C_COLUMN];
int8_t output_C_s8[C_ROW * C_COLUMN];

// Float outputs must agree within a relative tolerance.
bool check_identical(float matA[], float matB[], int size) {
    for (int i = 0; i < size; i++) {
        if (fabs((matA[i] - matB[i]) / matA[i]) > MAX_PRECISION_ERROR) {
            printf("%d: %f, %f\n", i, matA[i], matB[i]);
            return false;
        }
    }
    return true;
}

// int8 outputs must match exactly.
bool check_identical(int8_t matA[], int8_t matB[], int size) {
    for (int i = 0; i < size; i++) {
        if (matA[i] != matB[i]) {
            printf("%d: %d, %d\n", i, matA[i], matB[i]);
            return false;
        }
    }
    return true;
}

template <typename T>
void dump_integer_array(T matA[], int size) {
    for (int i = 0; i < size; i++) {
        printf("%d,", matA[i]);
    }
    printf("\n");
}

void initialize_matrix(float A[], int size) {
    for (int i = 0; i < size; i++) {
        A[i] = (float)(rand()) / (float)(RAND_MAX);
    }
}

void initialize_matrix(int8_t A[], int size) {
    for (int i = 0; i < size; i++) {
        // A[i] = (rand() % 2) - 1;
        A[i] = (rand() % 2);
    }
}

void initialize_matrix(int32_t A[], int size) {
    for (int i = 0; i < size; i++) {
        // A[i] = (rand() % 2) - 1;
        A[i] = (rand() % 2);
    }
}

using namespace matmul;

int main() {
    // initialize
    initialize_matrix(MAT_A, A_ROW * A_COLUMN);
    initialize_matrix(MAT_B, B_ROW * B_COLUMN);
    initialize_matrix(native_C, C_ROW * C_COLUMN);

    initialize_matrix(MAT_A_s8, A_ROW * A_COLUMN);
    initialize_matrix(MAT_B_s8, B_ROW * B_COLUMN);
    initialize_matrix(native_C_s8, C_ROW * C_COLUMN);
    // initialize_matrix(bias_s32, C_ROW * C_COLUMN);

    MatmulOperator matmul_op = MatmulOperator();

    struct matmul_params params, params_int8;
    params.A.row = A_ROW;
    params.A.column = A_COLUMN;
    params.A.data_ptr = MAT_A;
    params.B.row = B_ROW;
    params.B.column = B_COLUMN;
    params.B.data_ptr = MAT_B;
    params.C.row = C_ROW;
    params.C.column = C_COLUMN;
    params.opt_params.blk_size = BLK_SIZE;
    params.opt_params.num_thread = NUM_THREAD;

    // int8
    params_int8.A.row = A_ROW;
    params_int8.A.column = A_COLUMN;
    params_int8.A.int8_data_ptr = MAT_A_s8;
    params_int8.A.qparams.scale = 1.0;
    params_int8.A.qparams.zero_point = 0;
    params_int8.B.row = B_ROW;
    params_int8.B.column = B_COLUMN;
    params_int8.B.int8_data_ptr = MAT_B_s8;
    params_int8.B.qparams.scale = 1.0;
    params_int8.B.qparams.zero_point = 0;
    params_int8.C.row = C_ROW;
    params_int8.C.column = C_COLUMN;
    params_int8.C.int8_data_ptr = native_C_s8;
    params_int8.C.qparams.scale = 1.0;
    params_int8.C.qparams.q_max = 127;
    params_int8.C.qparams.q_min = -128;
    params_int8.C.qparams.zero_point = 0;
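    // Note: with all scales fixed at 1.0 and zero points at 0, the int8 paths
    // reduce to a plain integer GEMM clamped to [q_min, q_max]. In a real
    // quantized pipeline the accumulator would be requantized along the lines of
    //   C_q = clamp(round(acc * scale_A * scale_B / scale_C) + zp_C, q_min, q_max)
    // (a generic formulation, not necessarily the exact formula these kernels use).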
    params_int8.opt_params.blk_size = BLK_SIZE;
    params_int8.opt_params.num_thread = NUM_THREAD;
    params_int8.bias.row = 1;
    params_int8.bias.column = C_COLUMN;
    params_int8.bias.int32_data_ptr = bias_s32;

    // Baseline
    params.C.data_ptr = native_C;
    matmul_op.evaluate(MatmulOperator::NAIVE, &params);

    params.C.data_ptr = output_C;
    // unrolling
    matmul_op.evaluate(MatmulOperator::UNROLL, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_unrolling\n");

    // reordering
    matmul_op.evaluate(MatmulOperator::REORDER, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_reordering\n");

    // tiling
    matmul_op.evaluate(MatmulOperator::TILING, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_tiling\n");

    // multithreading
    matmul_op.evaluate(MatmulOperator::MULTITHREAD, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_multithreading\n");

    // transpose
    matmul_op.evaluate(MatmulOperator::TRANSPOSE, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_transpose\n");

    // transpose + simd
    initialize_matrix(output_C, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::TRANSPOSE_SIMD, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_transpose_simd\n");

    // cuda
#ifdef CUDA_ENABLE
    matmul_op.evaluate(MatmulOperator::CUDA, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_cuda\n");
#endif

    // ONEDNN
#ifdef ONEDNN_ENABLE
    initialize_matrix(output_C, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::ONEDNN_FP32, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("\nincorrect output of mat_mul_onednn\n");
#endif

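    // With B transposed, C[i][j] becomes a dot product of two arrays that are
    // both contiguous in memory (row i of A and row j of transpose_B), so the
    // inner loop over K can use unit-stride, vectorizable loads.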
    // For fast, we need to transpose B first
    for (int i = 0; i < B_COLUMN; i++)
        for (int j = 0; j < B_ROW; j++) transpose_B[i * B_ROW + j] = MAT_B[j * B_COLUMN + i];
    params.B.column = B_ROW;
    params.B.row = B_COLUMN;
    params.B.data_ptr = transpose_B;
    params.opt_params.blk_size = BLK_SIZE;
    params.opt_params.num_thread = NUM_THREAD;

    // fast
    initialize_matrix(output_C, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::FAST, &params);
    if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_fast\n");

    // int8
    matmul_op.evaluate(MatmulOperator::INT8_BASELINE, &params_int8);

    params_int8.C.int8_data_ptr = output_C_s8;

    // For int8 SIMD, we need to transpose B first
    for (int i = 0; i < B_COLUMN; i++)
        for (int j = 0; j < B_ROW; j++) transpose_B_s8[i * B_ROW + j] = MAT_B_s8[j * B_COLUMN + i];

    params_int8.B.int8_data_ptr = transpose_B_s8;
    initialize_matrix(output_C_s8, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::INT8_AVX, &params_int8);
    if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN))
        printf("incorrect output from mat_mul_avx_int8\n");

    initialize_matrix(output_C_s8, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST, &params_int8);
    if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN))
        printf("incorrect output from mat_mul_avx_int8_fast\n");

    initialize_matrix(output_C_s8, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST_2x2, &params_int8);
    if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN))
        printf("incorrect output from mat_mul_avx_int8_fast_2x2\n");

    initialize_matrix(output_C_s8, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST_2x2_32UNROLL, &params_int8);
    if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN))
        printf("incorrect output from mat_mul_avx_int8_fast_2x2_32unroll\n");

    initialize_matrix(output_C_s8, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST_2x2_OMP, &params_int8);
    if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN))
        printf("incorrect output from mat_mul_avx_int8_fast_2x2_omp\n");

    // ONEDNN
#ifdef ONEDNN_ENABLE
    initialize_matrix(output_C_s8, C_ROW * C_COLUMN);
    matmul_op.evaluate(MatmulOperator::ONEDNN_INT8, &params_int8);
    if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN))
        printf("incorrect output from mat_mul_onednn_int8\n");
#endif

    // Debugging
    // dump_integer_array(MAT_A_s8, A_ROW * A_COLUMN);
    // dump_integer_array(MAT_B_s8, B_ROW * B_COLUMN);
    // dump_integer_array(native_C_s8, C_ROW * C_COLUMN);
    // dump_integer_array(output_C_s8, C_ROW * C_COLUMN);

    return 0;
}
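
The AVX kernels exercised above live in lib/matmul_avx_int8.cc, which is among the commit's files but not in the hunks shown here. For orientation only, a minimal sketch of an AVX2 int8 dot product over pre-transposed B might look like the following; the function name dot_int8_avx2, its signature, and the assumption that K is a multiple of 16 are illustrative, not taken from the commit:

#include <immintrin.h>
#include <stdint.h>

// Dot product of two contiguous int8 vectors of length k (k % 16 == 0),
// e.g. one row of A against one row of the pre-transposed B.
int32_t dot_int8_avx2(const int8_t *a, const int8_t *b, int k) {
    __m256i acc = _mm256_setzero_si256();  // eight 32-bit partial sums
    for (int i = 0; i < k; i += 16) {
        // Load 16 signed bytes per operand and sign-extend to 16-bit lanes.
        __m256i va = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)(a + i)));
        __m256i vb = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)(b + i)));
        // Multiply 16-bit lanes pairwise and add adjacent products into 32 bits.
        acc = _mm256_add_epi32(acc, _mm256_madd_epi16(va, vb));
    }
    // Horizontal reduction of the eight 32-bit lanes to one scalar.
    __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(acc),
                                _mm256_extracti128_si256(acc, 1));
    sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2)));
    sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(2, 3, 0, 1)));
    return _mm_cvtsi128_si32(sum);
}

The 2x2 variants named in the commit message presumably keep a 2x2 tile of C in registers so each pass over K feeds four accumulators, and the OMP variant distributes output rows across threads; both are standard refinements of an inner loop like this one.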
