Skip to content

Commit 21a7480

Browse files
committed
Avoid unintended use of AVX512
1 parent 1eb3d57 commit 21a7480

File tree

1 file changed

+106
-10
lines changed

1 file changed

+106
-10
lines changed

thirdparties/spqlios/fft_processor_spqlios.cpp

Lines changed: 106 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,36 +38,65 @@ FFT_Processor_Spqlios::FFT_Processor_Spqlios(const int32_t N) : _2N(2 * N), N(N)
3838
}
3939

4040
void FFT_Processor_Spqlios::execute_reverse_uint(double *res, const uint32_t *a) {
41-
//for (int32_t i=0; i<N; i++) real_inout_rev[i]=(double)a[i];
41+
#ifdef USE_AVX512
4242
{
4343
double *dst = res;
44-
// double *dst = real_inout_rev;
4544
const uint32_t *ait = a;
4645
const uint32_t *aend = a + N;
46+
// __asm__ __volatile__ (
47+
// "0:\n"
48+
// "vmovupd (%1),%%xmm0\n"
49+
// "vcvtudq2pd %%xmm0,%%ymm1\n"
50+
// "vmovapd %%ymm1,(%0)\n"
51+
// "addq $16,%1\n"
52+
// "addq $32,%0\n"
53+
// "cmpq %2,%1\n"
54+
// "jb 0b\n"
55+
// : "=r"(dst), "=r"(ait), "=r"(aend)
56+
// : "0"(dst), "1"(ait), "2"(aend)
57+
// : "%xmm0", "%ymm1", "memory"
58+
// );
4759
__asm__ __volatile__ (
4860
"0:\n"
49-
"vmovupd (%1),%%xmm0\n"
50-
"vcvtudq2pd %%xmm0,%%ymm1\n"
61+
"vmovupd (%1),%%ymm0\n"
62+
"vcvtudq2pd %%ymm0,%%zmm1\n"
5163
"vmovapd %%ymm1,(%0)\n"
52-
"addq $16,%1\n"
53-
"addq $32,%0\n"
64+
"addq $32,%1\n"
65+
"addq $64,%0\n"
5466
"cmpq %2,%1\n"
5567
"jb 0b\n"
5668
: "=r"(dst), "=r"(ait), "=r"(aend)
5769
: "0"(dst), "1"(ait), "2"(aend)
58-
: "%xmm0", "%ymm1", "memory"
70+
: "%ymm0", "%zmm1", "memory"
5971
);
6072
}
73+
#else
74+
for (int32_t i=0; i<N; i++) res[i]=(double)a[i];
75+
#endif
6176
ifft(tables_reverse, res);
6277
}
6378

6479
void FFT_Processor_Spqlios::execute_reverse_int(double *res, const int32_t *a) {
6580
//for (int32_t i=0; i<N; i++) real_inout_rev[i]=(double)a[i];
6681
{
6782
double *dst = res;
68-
// double *dst = real_inout_rev;
6983
const int32_t *ait = a;
7084
const int32_t *aend = a + N;
85+
#ifdef USE_AVX512
86+
__asm__ __volatile__ (
87+
"0:\n"
88+
"vmovdqu32 (%1),%%zmm0\n" // Load 16 int32_t values from `ait` into zmm0
89+
"vcvtdq2pd %%zmm0,%%zmm1\n" // Convert 16 int32_t values to 8 double-precision values
90+
"vmovapd %%zmm1,(%0)\n" // Store the result (8 doubles) in `dst`
91+
"addq $64,%1\n" // Increment `ait` by 64 bytes (16 int32_t values)
92+
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 double-precision values)
93+
"cmpq %2,%1\n" // Compare `ait` with `aend`
94+
"jb 0b\n" // Jump back if `ait < aend`
95+
: "=r"(dst), "=r"(ait), "=r"(aend)
96+
: "0"(dst), "1"(ait), "2"(aend)
97+
: "%zmm0", "%zmm1", "memory"
98+
);
99+
#else
71100
__asm__ __volatile__ (
72101
"0:\n"
73102
"vmovupd (%1),%%xmm0\n"
@@ -81,6 +110,7 @@ void FFT_Processor_Spqlios::execute_reverse_int(double *res, const int32_t *a) {
81110
: "0"(dst), "1"(ait), "2"(aend)
82111
: "%xmm0", "%ymm1", "memory"
83112
);
113+
#endif
84114
}
85115
ifft(tables_reverse, res);
86116
}
@@ -110,8 +140,23 @@ void FFT_Processor_Spqlios::execute_direct_torus32(uint32_t *res, const double *
110140
double *dst = real_inout_direct;
111141
const double *sit = a;
112142
const double *send = a + N;
113-
//double __2sN = 2./N;
114143
const double *bla = &_2sN;
144+
#ifdef AVX512
145+
__asm__ __volatile__ (
146+
"vbroadcastsd (%3),%%zmm2\n" // Broadcast _2sN to zmm2
147+
"1:\n"
148+
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision values from `sit` into zmm0
149+
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply zmm0 by zmm2
150+
"vmovupd %%zmm0,(%0)\n" // Store the result in `dst`
151+
"addq $64,%1\n" // Increment `sit` by 64 bytes (8 doubles)
152+
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 doubles)
153+
"cmpq %2,%1\n" // Compare `sit` with `send`
154+
"jb 1b\n" // Jump if `sit` < `send`
155+
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
156+
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
157+
: "%zmm0", "%zmm2", "memory"
158+
);
159+
#else
115160
__asm__ __volatile__ (
116161
"vbroadcastsd (%3),%%ymm2\n"
117162
"1:\n"
@@ -126,6 +171,7 @@ void FFT_Processor_Spqlios::execute_direct_torus32(uint32_t *res, const double *
126171
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
127172
: "%ymm0", "%ymm2", "memory"
128173
);
174+
#endif
129175
}
130176
fft(tables_direct, real_inout_direct);
131177
// for (int32_t i = 0; i < N; i++) res[i] = uint32_t(int64_t(real_inout_direct[i]));
@@ -142,6 +188,22 @@ void FFT_Processor_Spqlios::execute_direct_torus32_q(uint32_t *res, const double
142188
const double *send = a + N;
143189
//double __2sN = 2./N;
144190
const double *bla = &_2sN;
191+
#ifdef USE_AVX512
192+
__asm__ __volatile__ (
193+
"vbroadcastsd (%3),%%zmm2\n" // Broadcast _2sN to zmm2
194+
"1:\n"
195+
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision values from `sit` into zmm0
196+
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply zmm0 by zmm2
197+
"vmovupd %%zmm0,(%0)\n" // Store the result in `dst`
198+
"addq $64,%1\n" // Increment `sit` by 64 bytes (8 doubles)
199+
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 doubles)
200+
"cmpq %2,%1\n" // Compare `sit` with `send`
201+
"jb 1b\n" // Jump if `sit` < `send`
202+
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
203+
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
204+
: "%zmm0", "%zmm2", "memory"
205+
);
206+
#else
145207
__asm__ __volatile__ (
146208
"vbroadcastsd (%3),%%ymm2\n"
147209
"1:\n"
@@ -156,6 +218,7 @@ void FFT_Processor_Spqlios::execute_direct_torus32_q(uint32_t *res, const double
156218
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
157219
: "%ymm0", "%ymm2", "memory"
158220
);
221+
#endif
159222
}
160223
fft(tables_direct, real_inout_direct);
161224
for (int32_t i = 0; i < N; i++) res[i] = uint32_t((int64_t(real_inout_direct[i])%q+q)%q);
@@ -169,8 +232,23 @@ void FFT_Processor_Spqlios::execute_direct_torus32_rescale(uint32_t *res, const
169232
double *dst = real_inout_direct;
170233
const double *sit = a;
171234
const double *send = a + N;
172-
//double __2sN = 2./N;
173235
const double *bla = &_2sN;
236+
#ifdef USE_AVX512
237+
__asm__ __volatile__ (
238+
"vbroadcastsd (%3),%%zmm2\n" // Broadcast _2sN to zmm2
239+
"1:\n"
240+
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision values from `sit` into zmm0
241+
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply zmm0 by zmm2
242+
"vmovupd %%zmm0,(%0)\n" // Store the result in `dst`
243+
"addq $64,%1\n" // Increment `sit` by 64 bytes (8 doubles)
244+
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 doubles)
245+
"cmpq %2,%1\n" // Compare `sit` with `send`
246+
"jb 1b\n" // Jump if `sit` < `send`
247+
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
248+
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
249+
: "%zmm0", "%zmm2", "memory"
250+
);
251+
#else
174252
__asm__ __volatile__ (
175253
"vbroadcastsd (%3),%%ymm2\n"
176254
"1:\n"
@@ -185,6 +263,7 @@ void FFT_Processor_Spqlios::execute_direct_torus32_rescale(uint32_t *res, const
185263
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
186264
: "%ymm0", "%ymm2", "memory"
187265
);
266+
#endif
188267
}
189268
fft(tables_direct, real_inout_direct);
190269
for (int32_t i = 0; i < N; i++) res[i] = static_cast<uint32_t>(int64_t(real_inout_direct[i]/Δ));
@@ -200,6 +279,22 @@ void FFT_Processor_Spqlios::execute_direct_torus64(uint64_t* res, const double*
200279
const double* send = a+N;
201280
//double __2sN = 2./N;
202281
const double* bla = &_2sN;
282+
#ifdef USE_AVX512
283+
__asm__ __volatile__ (
284+
"vbroadcastsd (%3),%%zmm2\n" // Broadcast 2sN to zmm2
285+
"1:\n"
286+
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision floats from `sit` into zmm0
287+
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply the vector by zmm2
288+
"vmovapd %%zmm0,(%0)\n" // Store the result into `dst`
289+
"addq $64,%1\n" // Increment `sit` by 64 (8 doubles * 8 bytes per double)
290+
"addq $64,%0\n" // Increment `dst` by 64 (8 doubles * 8 bytes per double)
291+
"cmpq %2,%1\n" // Compare `sit` with `send`
292+
"jb 1b\n" // Jump back if not done
293+
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
294+
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
295+
: "%zmm0", "%zmm2", "memory"
296+
);
297+
#else
203298
__asm__ __volatile__ (
204299
"vbroadcastsd (%3),%%ymm2\n"
205300
"1:\n"
@@ -214,6 +309,7 @@ void FFT_Processor_Spqlios::execute_direct_torus64(uint64_t* res, const double*
214309
: "0"(dst),"1"(sit),"2"(send),"3"(bla)
215310
: "%ymm0","%ymm2","memory"
216311
);
312+
#endif
217313
}
218314
fft(tables_direct,real_inout_direct);
219315
#ifdef USE_AVX512

0 commit comments

Comments
 (0)