Skip to content

Commit ee69a9c

Browse files
committed
update montgomery mul
1 parent 7a64f18 commit ee69a9c

File tree

2 files changed

+34
-9
lines changed

2 files changed

+34
-9
lines changed

cp-algo/math/fft.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ namespace cp_algo::math::fft {
1717
static const int splt = int(std::sqrt(base::mod())) + 1;
1818
return splt;
1919
}
20-
static u64x4 mod, imod;
20+
static uint32_t mod, imod;
2121

2222
static void init() {
2323
if(!_init) {
2424
factor = 1 + random::rng() % (base::mod() - 1);
2525
ifactor = base(1) / factor;
26-
mod = u64x4() + base::mod();
27-
imod = u64x4() + inv2(-base::mod());
26+
mod = base::mod();
27+
imod = -inv2(base::mod());
2828
_init = true;
2929
}
3030
}
@@ -199,8 +199,8 @@ namespace cp_algo::math::fft {
199199
// Out-of-class definitions of the static data members of dft<base>.
// factor and ifactor start at 1, _init starts false, and mod / imod are
// value-initialized; init() assigns factor, ifactor, mod and imod and then
// sets _init.
// NOTE(review): this is a diff rendering -- the u64x4 pair below is the
// pre-commit form and the uint32_t pair is its replacement.
template<modint_type base> base dft<base>::factor = 1;
200200
template<modint_type base> base dft<base>::ifactor = 1;
201201
template<modint_type base> bool dft<base>::_init = false;
202-
template<modint_type base> u64x4 dft<base>::mod = {};
203-
template<modint_type base> u64x4 dft<base>::imod = {};
202+
template<modint_type base> uint32_t dft<base>::mod = {};
203+
template<modint_type base> uint32_t dft<base>::imod = {};
204204

205205
void mul_slow(auto &a, auto const& b, size_t k) {
206206
if(std::empty(a) || std::empty(b)) {

cp-algo/util/simd.hpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,49 @@ namespace cp_algo {
3838
};
3939
}
4040

41-
[[gnu::always_inline]] inline u64x4 montgomery_reduce(u64x4 x, u64x4 mod, u64x4 imod) {
42-
auto x_ninv = u64x4(u32x8(x) * u32x8(imod));
41+
/// Montgomery-style reduction of four 64-bit lanes at once.
///
/// For every 64-bit lane of `x` this computes q = low32(x) * imod (mod 2^32)
/// and returns (x + q * mod) >> 32.
///
/// @param x    Four 64-bit values to reduce.
/// @param mod  The modulus, applied to every lane.
/// @param imod Multiplier applied to the low half of each lane; presumably
///             -mod^{-1} modulo 2^32 so that the low 32 bits of x + q * mod
///             cancel -- confirm against the code that computes it.
/// @return The upper 32 bits of x + q * mod in each 64-bit lane. No final
///         conditional subtraction of `mod` is done here.
[[gnu::always_inline]] inline u64x4 montgomery_reduce(u64x4 x, uint32_t mod, uint32_t imod) {
    // 32-bit lane-wise multiply. The low half of each 64-bit lane now holds q;
    // the high half holds high32(x) * imod, which is not part of the reduction.
    auto x_ninv = u64x4(u32x8(x) * (u32x8() + imod));
#ifdef __AVX2__
    // _mm256_mul_epu32 reads only the low 32 bits of each 64-bit lane, so the
    // high halves of x_ninv are ignored.
    x += u64x4(_mm256_mul_epu32(__m256i(x_ninv), __m256i() + mod));
#else
    // A plain 64-bit multiply would also pull in the high halves of x_ninv and
    // corrupt the upper 32 bits of the sum, so mask them off to match the
    // AVX2 path.
    x += (x_ninv & uint32_t(-1)) * mod;
#endif
    return x >> 32;
}
5050

51-
[[gnu::always_inline]] inline u64x4 montgomery_mul(u64x4 x, u64x4 y, u64x4 mod, u64x4 imod) {
51+
/// Multiplies `x` and `y` lane by lane and passes each 64-bit product through
/// montgomery_reduce with the given `mod` and `imod`.
///
/// With AVX2 only the low 32 bits of every 64-bit lane of `x` and `y` enter
/// the product (_mm256_mul_epu32); without it the full 64-bit lanes are
/// multiplied.
[[gnu::always_inline]] inline u64x4 montgomery_mul(u64x4 x, u64x4 y, uint32_t mod, uint32_t imod) {
#ifdef __AVX2__
    const u64x4 product = u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y)));
#else
    const u64x4 product = x * y;
#endif
    return montgomery_reduce(product, mod, imod);
}
5858

59+
/// Montgomery multiplication of eight 32-bit lanes at once.
///
/// The lanes are split into the low halves (lanes 0, 2, 4, 6) and the high
/// halves (lanes 1, 3, 5, 7) of the four 64-bit words, multiplied as 32x32->64
/// products, reduced, and packed back into eight 32-bit lanes.
///
/// @param x, y Operands, eight 32-bit values each.
/// @param mod  The modulus, applied to every lane.
/// @param imod Multiplier applied to the low 32 bits of each product;
///             presumably -mod^{-1} modulo 2^32 -- confirm against the code
///             that computes it.
/// @return For each lane, the upper 32 bits of x*y + q*mod, where
///         q = low32(x*y) * imod (mod 2^32). No final conditional subtraction
///         of `mod` is done here.
///
/// Declared inline (like the other helpers in this header) so that including
/// the header from several translation units does not produce multiple
/// definitions.
[[gnu::always_inline]] inline u32x8 montgomery_mul(u32x8 x, u32x8 y, uint32_t mod, uint32_t imod) {
    auto x0246 = u64x4(x) & uint32_t(-1);
    auto y0246 = u64x4(y) & uint32_t(-1);
    auto x1357 = u64x4(x) >> 32;
    auto y1357 = u64x4(y) >> 32;
#ifdef __AVX2__
    auto xy0246 = u64x4(_mm256_mul_epu32(__m256i(x0246), __m256i(y0246)));
    auto xy1357 = u64x4(_mm256_mul_epu32(__m256i(x1357), __m256i(y1357)));
#else
    u64x4 xy0246 = x0246 * y0246;
    u64x4 xy1357 = x1357 * y1357;
#endif
    // Pack the LOW 32 bits of both products into one vector of eight lanes.
    // xy0246 must be masked first: its upper half is a live part of the
    // 64-bit product and would otherwise be OR-ed into the odd lanes.
    auto xy_low = (xy0246 & uint32_t(-1)) | (xy1357 << 32);
    auto xy_inv = u64x4(u32x8(xy_low) * (u32x8() + imod));
    auto xy_inv0246 = xy_inv & uint32_t(-1);
    auto xy_inv1357 = xy_inv >> 32;
#ifdef __AVX2__
    xy0246 += u64x4(_mm256_mul_epu32(__m256i(xy_inv0246), __m256i() + mod));
    xy1357 += u64x4(_mm256_mul_epu32(__m256i(xy_inv1357), __m256i() + mod));
#else
    xy0246 += xy_inv0246 * mod;
    xy1357 += xy_inv1357 * mod;
#endif
    // Results sit in the upper halves: move the even-lane results down and
    // keep the odd-lane results in place.
    return u32x8((xy0246 >> 32) | (xy1357 & (-1ULL << 32)));
}
83+
5984
[[gnu::always_inline]] inline dx4 rotate_right(dx4 x) {
6085
static constexpr u64x4 shuffler = {3, 0, 1, 2};
6186
return __builtin_shuffle(x, shuffler);

0 commit comments

Comments
 (0)