Skip to content

Commit 8451c7b

Browse files
committed
use u32x8 in many_facts
1 parent ee69a9c commit 8451c7b

File tree

2 files changed

+21
-37
lines changed

2 files changed

+21
-37
lines changed

cp-algo/util/simd.hpp

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -56,29 +56,13 @@ namespace cp_algo {
5656
#endif
5757
}
5858

59-
u32x8 montgomery_mul(u32x8 x, u32x8 y, uint32_t mod, uint32_t imod) {
59+
[[gnu::always_inline]] inline u32x8 montgomery_mul(u32x8 x, u32x8 y, uint32_t mod, uint32_t imod) {
6060
auto x0246 = u64x4(x) & uint32_t(-1);
6161
auto y0246 = u64x4(y) & uint32_t(-1);
6262
auto x1357 = u64x4(x) >> 32;
6363
auto y1357 = u64x4(y) >> 32;
64-
#ifdef __AVX2__
65-
auto xy0246 = u64x4(_mm256_mul_epu32(__m256i(x0246), __m256i(y0246)));
66-
auto xy1357 = u64x4(_mm256_mul_epu32(__m256i(x1357), __m256i(y1357)));
67-
#else
68-
u64x4 xy0246 = x0246 * y0246;
69-
u64x4 xy1357 = x1357 * y1357;
70-
#endif
71-
auto xy_inv = u64x4(u32x8(xy0246 | (xy1357 << 32)) * (u32x8() + imod));
72-
auto xy_inv0246 = xy_inv & uint32_t(-1);
73-
auto xy_inv1357 = xy_inv >> 32;
74-
#ifdef __AVX2__
75-
xy0246 += u64x4(_mm256_mul_epu32(__m256i(xy_inv0246), __m256i() + mod));
76-
xy1357 += u64x4(_mm256_mul_epu32(__m256i(xy_inv1357), __m256i() + mod));
77-
#else
78-
xy0246 += xy_inv0246 * mod;
79-
xy1357 += xy_inv1357 * mod;
80-
#endif
81-
return u32x8((xy0246 >> 32) | (xy1357 & -1ULL << 32));
64+
return u32x8(montgomery_mul(x0246, y0246, mod, imod)) |
65+
u32x8(montgomery_mul(x1357, y1357, mod, imod) << 32);
8266
}
8367

8468
[[gnu::always_inline]] inline dx4 rotate_right(dx4 x) {

verify/simd/many_facts.test.cpp

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,7 @@ using namespace std;
1010
using namespace cp_algo;
1111

1212
constexpr int mod = 998244353;
13-
constexpr auto mod4 = u64x4() + mod;
14-
constexpr auto imod4 = u64x4() - math::inv2(mod);
13+
constexpr int imod = -math::inv2(mod);
1514

1615
void facts_inplace(vector<int> &args) {
1716
constexpr int block = 1 << 16;
@@ -26,39 +25,40 @@ void facts_inplace(vector<int> &args) {
2625
args_per_block[(mod - x - 1) / block].push_back(i);
2726
}
2827
}
29-
uint64_t b2x32 = (1ULL << 32) % mod;
28+
uint32_t b2x32 = (1ULL << 32) % mod;
3029
uint64_t fact = 1;
31-
const int K = 4;
32-
for(uint64_t b = 0; b <= limit; b += K * block) {
33-
u64x4 cur[K];
34-
static array<u64x4, block / 4> prods[K];
35-
for(int z = 0; z < K; z++) {
36-
for(int j = 0; j < 4; j++) {
37-
cur[z][j] = b + z * block + j * block / 4;
30+
const int accum = 4;
31+
const int simd_size = 8;
32+
for(uint64_t b = 0; b <= limit; b += accum * block) {
33+
u32x8 cur[accum];
34+
static array<u32x8, block / simd_size> prods[accum];
35+
for(int z = 0; z < accum; z++) {
36+
for(int j = 0; j < simd_size; j++) {
37+
cur[z][j] = uint32_t(b + z * block + j * (block / simd_size));
3838
prods[z][0][j] = cur[z][j] + !(b || z || j);
3939
#pragma GCC diagnostic push
4040
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
41-
cur[z][j] = cur[z][j] * b2x32 % mod;
41+
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
4242
#pragma GCC diagnostic pop
4343
}
4444
}
45-
for(int i = 1; i < block / 4; i++) {
46-
for(int z = 0; z < K; z++) {
45+
for(int i = 1; i < block / simd_size; i++) {
46+
for(int z = 0; z < accum; z++) {
4747
cur[z] += b2x32;
4848
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
49-
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod4, imod4);
49+
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
5050
}
5151
}
52-
for(int z = 0; z < K; z++) {
52+
for(int z = 0; z < accum; z++) {
5353
uint64_t bl = b + z * block;
5454
for(auto i: args_per_block[bl / block]) {
5555
size_t x = args[i];
5656
if(x >= mod / 2) {
5757
x = mod - x - 1;
5858
}
5959
x -= bl;
60-
auto pre_blocks = x / (block / 4);
61-
auto in_block = x % (block / 4);
60+
auto pre_blocks = x / (block / simd_size);
61+
auto in_block = x % (block / simd_size);
6262
auto ans = fact * prods[z][in_block][pre_blocks] % mod;
6363
for(size_t j = 0; j < pre_blocks; j++) {
6464
ans = ans * prods[z].back()[j] % mod;
@@ -71,7 +71,7 @@ void facts_inplace(vector<int> &args) {
7171
}
7272
}
7373
args_per_block[bl / block].clear();
74-
for(int j = 0; j < 4; j++) {
74+
for(int j = 0; j < simd_size; j++) {
7575
fact = fact * prods[z].back()[j] % mod;
7676
}
7777
}

0 commit comments

Comments
 (0)