|
3 | 3 | * Date: 2019-04-16
|
4 | 4 | * License: CC0
|
5 | 5 | * Source: based on KACTL's FFT
|
6 |
| - * Description: Can be used for convolutions modulo specific nice primes |
7 |
| - * of the form $2^a b+1$, where the convolution result has size at most $2^a$. |
| 6 | + * Description: ntt(a) computes $\hat f(k) = \sum_x a[x] g^{xk}$ for all $k$, where $g=\text{root}^{(mod-1)/N}$. |
| 7 | + * N must be a power of 2. |
| 8 | + * Useful for convolution modulo specific nice primes of the form $2^a b+1$, |
| 9 | + * where the convolution result has size at most $2^a$. For arbitrary modulo, see FFTMod. |
| 10 | + \texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$. |
| 11 | + For manual convolution: NTT the inputs, multiply |
| 12 | + pointwise, divide by n, reverse(start+1, end), NTT back. |
8 | 13 | * Inputs must be in [0, mod).
|
9 | 14 | * Time: O(N \log N)
|
10 | 15 | * Status: stress-tested
|
|
16 | 21 | const ll mod = (119 << 23) + 1, root = 62; // = 998244353
|
17 | 22 | // For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21
|
18 | 23 | // and 483 << 21 (same root). The last two are > 10^9.
|
19 |
| - |
20 | 24 | typedef vector<ll> vl;
|
21 |
| -void ntt(vl& a, vl& rt, vl& rev, int n) { |
| 25 | +void ntt(vl &a) { |
| 26 | + int n = sz(a), L = 31 - __builtin_clz(n); |
| 27 | + static vl rt(2, 1); |
| 28 | + for (static int k = 2, s = 2; k < n; k *= 2, s++) { |
| 29 | + rt.resize(n); |
| 30 | + ll z[] = {1, modpow(root, mod >> s)}; |
| 31 | + rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod; |
| 32 | + } |
| 33 | + vi rev(n); |
| 34 | + rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2; |
22 | 35 | rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
|
23 | 36 | for (int k = 1; k < n; k *= 2)
|
24 | 37 | for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
|
25 |
| - ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j]; |
26 |
| - a[i + j + k] = (z > ai ? ai - z + mod : ai - z); |
27 |
| - ai += (ai + z >= mod ? z - mod : z); |
28 |
| - } |
| 38 | + ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j]; |
| 39 | + a[i + j + k] = ai - z + (z > ai ? mod : 0); |
| 40 | + ai += (ai + z >= mod ? z - mod : z); |
| 41 | + } |
29 | 42 | }
|
30 |
| - |
31 |
| -vl conv(const vl& a, const vl& b) { |
32 |
| - if (a.empty() || b.empty()) |
33 |
| - return {}; |
34 |
| - int s = sz(a)+sz(b)-1, B = 32 - __builtin_clz(s), n = 1 << B; |
35 |
| - vl L(a), R(b), out(n), rt(n, 1), rev(n); |
| 43 | +vl conv(const vl &a, const vl &b) { |
| 44 | + if (a.empty() || b.empty()) return {}; |
| 45 | + int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s), n = 1 << B; |
| 46 | + int inv = modpow(n, mod - 2); |
| 47 | + vl L(a), R(b), out(n); |
36 | 48 | L.resize(n), R.resize(n);
|
37 |
| - rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << B) / 2; |
38 |
| - ll curL = mod / 2, inv = modpow(n, mod - 2); |
39 |
| - for (int k = 2; k < n; k *= 2) { |
40 |
| - ll z[] = {1, modpow(root, curL /= 2)}; |
41 |
| - rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod; |
42 |
| - } |
43 |
| - ntt(L, rt, rev, n); ntt(R, rt, rev, n); |
44 |
| - rep(i,0,n) out[-i & (n-1)] = L[i] * R[i] % mod * inv % mod; |
45 |
| - ntt(out, rt, rev, n); |
| 49 | + ntt(L), ntt(R); |
| 50 | + rep(i,0,n) out[-i & (n - 1)] = (ll)L[i] * R[i] % mod * inv % mod; |
| 51 | + ntt(out); |
46 | 52 | return {out.begin(), out.begin() + s};
|
47 | 53 | }
|
0 commit comments