@@ -10,8 +10,7 @@ using namespace std;
10
10
using namespace cp_algo ;
11
11
12
12
constexpr int mod = 998244353 ;
13
- constexpr auto mod4 = u64x4() + mod;
14
- constexpr auto imod4 = u64x4() - math::inv2(mod);
13
+ constexpr int imod = -math::inv2(mod);
15
14
16
15
void facts_inplace (vector<int > &args) {
17
16
constexpr int block = 1 << 16 ;
@@ -26,39 +25,40 @@ void facts_inplace(vector<int> &args) {
26
25
args_per_block[(mod - x - 1 ) / block].push_back (i);
27
26
}
28
27
}
29
- uint64_t b2x32 = (1ULL << 32 ) % mod;
28
+ uint32_t b2x32 = (1ULL << 32 ) % mod;
30
29
uint64_t fact = 1 ;
31
- const int K = 4 ;
32
- for (uint64_t b = 0 ; b <= limit; b += K * block) {
33
- u64x4 cur[K];
34
- static array<u64x4, block / 4 > prods[K];
35
- for (int z = 0 ; z < K; z++) {
36
- for (int j = 0 ; j < 4 ; j++) {
37
- cur[z][j] = b + z * block + j * block / 4 ;
30
+ const int accum = 4 ;
31
+ const int simd_size = 8 ;
32
+ for (uint64_t b = 0 ; b <= limit; b += accum * block) {
33
+ u32x8 cur[accum];
34
+ static array<u32x8, block / simd_size> prods[accum];
35
+ for (int z = 0 ; z < accum; z++) {
36
+ for (int j = 0 ; j < simd_size; j++) {
37
+ cur[z][j] = uint32_t (b + z * block + j * (block / simd_size));
38
38
prods[z][0 ][j] = cur[z][j] + !(b || z || j);
39
39
#pragma GCC diagnostic push
40
40
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
41
- cur[z][j] = cur[z][j] * b2x32 % mod;
41
+ cur[z][j] = uint32_t ( uint64_t ( cur[z][j]) * b2x32 % mod) ;
42
42
#pragma GCC diagnostic pop
43
43
}
44
44
}
45
- for (int i = 1 ; i < block / 4 ; i++) {
46
- for (int z = 0 ; z < K ; z++) {
45
+ for (int i = 1 ; i < block / simd_size ; i++) {
46
+ for (int z = 0 ; z < accum ; z++) {
47
47
cur[z] += b2x32;
48
48
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
49
- prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod4, imod4 );
49
+ prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod, imod );
50
50
}
51
51
}
52
- for (int z = 0 ; z < K ; z++) {
52
+ for (int z = 0 ; z < accum ; z++) {
53
53
uint64_t bl = b + z * block;
54
54
for (auto i: args_per_block[bl / block]) {
55
55
size_t x = args[i];
56
56
if (x >= mod / 2 ) {
57
57
x = mod - x - 1 ;
58
58
}
59
59
x -= bl;
60
- auto pre_blocks = x / (block / 4 );
61
- auto in_block = x % (block / 4 );
60
+ auto pre_blocks = x / (block / simd_size );
61
+ auto in_block = x % (block / simd_size );
62
62
auto ans = fact * prods[z][in_block][pre_blocks] % mod;
63
63
for (size_t j = 0 ; j < pre_blocks; j++) {
64
64
ans = ans * prods[z].back ()[j] % mod;
@@ -71,7 +71,7 @@ void facts_inplace(vector<int> &args) {
71
71
}
72
72
}
73
73
args_per_block[bl / block].clear ();
74
- for (int j = 0 ; j < 4 ; j++) {
74
+ for (int j = 0 ; j < simd_size ; j++) {
75
75
fact = fact * prods[z].back ()[j] % mod;
76
76
}
77
77
}
0 commit comments