Skip to content

Commit 61bc640

Browse files
authored
Updated NTT to share interface with FFT (#167)
1 parent 0897eb9 commit 61bc640

File tree

3 files changed

+49
-34
lines changed

3 files changed

+49
-34
lines changed

content/numerical/FastFourierTransform.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
* License: CC0
55
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf (do read, it's excellent)
66
Accuracy bound from http://www.daemonology.net/papers/fft.pdf
7-
* Description: fft(a) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution:
7+
* Description: fft(a) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. N must be a power of 2.
8+
Useful for convolution:
89
\texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$.
910
For convolution of complex numbers or more than two vectors: FFT, multiply
1011
pointwise, divide by n, reverse(start+1, end), FFT back.
1112
Rounding is safe if $(\sum a_i^2 + \sum b_i^2)\log_2{N} < 9\cdot10^{14}$
1213
(in practice $10^{16}$; higher for random inputs).
13-
Otherwise, use long doubles/NTT/FFTMod.
14+
Otherwise, use NTT/FFTMod.
1415
* Time: O(N \log N) with $N = |A|+|B|$ ($\sim 1s$ for $N=2^{22}$)
1516
* Status: somewhat tested
1617
* Details: An in-depth examination of precision for both FFT and FFTMod can be found

content/numerical/NumberTheoreticTransform.h

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,13 @@
33
* Date: 2019-04-16
44
* License: CC0
55
* Source: based on KACTL's FFT
6-
* Description: Can be used for convolutions modulo specific nice primes
7-
* of the form $2^a b+1$, where the convolution result has size at most $2^a$.
6+
* Description: ntt(a) computes $\hat f(k) = \sum_x a[x] g^{xk}$ for all $k$, where $g=\text{root}^{(mod-1)/N}$.
7+
* N must be a power of 2.
8+
* Useful for convolution modulo specific nice primes of the form $2^a b+1$,
9+
* where the convolution result has size at most $2^a$. For an arbitrary modulus, see FFTMod.
10+
\texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$.
11+
For manual convolution: NTT the inputs, multiply
12+
pointwise, divide by n, reverse(start+1, end), NTT back.
813
* Inputs must be in [0, mod).
914
* Time: O(N \log N)
1015
* Status: stress-tested
@@ -16,32 +21,33 @@
1621
const ll mod = (119 << 23) + 1, root = 62; // = 998244353
1722
// For p < 2^30 there is also e.g. 5 << 25, 7 << 26, 479 << 21
1823
// and 483 << 21 (same root), each plus 1 as in mod above. The last two are > 10^9.
19-
2024
typedef vector<ll> vl;
21-
void ntt(vl& a, vl& rt, vl& rev, int n) {
25+
void ntt(vl &a) {
26+
int n = sz(a), L = 31 - __builtin_clz(n);
27+
static vl rt(2, 1);
28+
for (static int k = 2, s = 2; k < n; k *= 2, s++) {
29+
rt.resize(n);
30+
ll z[] = {1, modpow(root, mod >> s)};
31+
rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod;
32+
}
33+
vi rev(n);
34+
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
2235
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
2336
for (int k = 1; k < n; k *= 2)
2437
for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
25-
ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j];
26-
a[i + j + k] = (z > ai ? ai - z + mod : ai - z);
27-
ai += (ai + z >= mod ? z - mod : z);
28-
}
38+
ll z = rt[j + k] * a[i + j + k] % mod, &ai = a[i + j];
39+
a[i + j + k] = ai - z + (z > ai ? mod : 0);
40+
ai += (ai + z >= mod ? z - mod : z);
41+
}
2942
}
30-
31-
vl conv(const vl& a, const vl& b) {
32-
if (a.empty() || b.empty())
33-
return {};
34-
int s = sz(a)+sz(b)-1, B = 32 - __builtin_clz(s), n = 1 << B;
35-
vl L(a), R(b), out(n), rt(n, 1), rev(n);
43+
vl conv(const vl &a, const vl &b) {
44+
if (a.empty() || b.empty()) return {};
45+
int s = sz(a) + sz(b) - 1, B = 32 - __builtin_clz(s), n = 1 << B;
46+
int inv = modpow(n, mod - 2);
47+
vl L(a), R(b), out(n);
3648
L.resize(n), R.resize(n);
37-
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << B) / 2;
38-
ll curL = mod / 2, inv = modpow(n, mod - 2);
39-
for (int k = 2; k < n; k *= 2) {
40-
ll z[] = {1, modpow(root, curL /= 2)};
41-
rep(i,k,2*k) rt[i] = rt[i / 2] * z[i & 1] % mod;
42-
}
43-
ntt(L, rt, rev, n); ntt(R, rt, rev, n);
44-
rep(i,0,n) out[-i & (n-1)] = L[i] * R[i] % mod * inv % mod;
45-
ntt(out, rt, rev, n);
49+
ntt(L), ntt(R);
50+
rep(i,0,n) out[-i & (n - 1)] = (ll)L[i] * R[i] % mod * inv % mod;
51+
ntt(out);
4652
return {out.begin(), out.begin() + s};
4753
}

stress-tests/numerical/NumberTheoreticTransform.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ namespace ignore {
77
ll modpow(ll a, ll e);
88
#include "../../content/numerical/NumberTheoreticTransform.h"
99
ll modpow(ll a, ll e) {
10-
if (e == 0) return 1;
11-
ll x = modpow(a * a % mod, e >> 1);
12-
return e & 1 ? x * a % mod : x;
10+
if (e == 0)
11+
return 1;
12+
ll x = modpow(a * a % mod, e >> 1);
13+
return e & 1 ? x * a % mod : x;
1314
}
1415

15-
1616
vl simpleConv(vl a, vl b) {
1717
int s = sz(a) + sz(b) - 1;
1818
if (a.empty() || b.empty()) return {};
@@ -24,11 +24,11 @@ vl simpleConv(vl a, vl b) {
2424
}
2525

2626
int ra() {
27-
static unsigned X;
28-
X *= 123671231;
29-
X += 1238713;
30-
X ^= 1237618;
31-
return (X >> 1);
27+
static unsigned X;
28+
X *= 123671231;
29+
X += 1238713;
30+
X ^= 1237618;
31+
return (X >> 1);
3232
}
3333

3434
int main() {
@@ -42,6 +42,14 @@ int main() {
4242
for(auto &x: b) x = (ra() % 100 - 50+mod)%mod;
4343
for(auto &x: simpleConv(a, b)) res += (ll)x * ind++ % mod;
4444
for(auto &x: conv(a, b)) res2 += (ll)x * ind2++ % mod;
45+
a.resize(16);
46+
vl a2 = a;
47+
ntt(a2);
48+
rep(k, 0, sz(a2)) {
49+
ll sum = 0;
50+
rep(x, 0, sz(a2)) { sum = (sum + a[x] * modpow(root, k * x * (mod - 1) / sz(a))) % mod; }
51+
assert(sum == a2[k]);
52+
}
4553
}
4654
assert(res==res2);
4755
cout<<"Tests passed!"<<endl;

0 commit comments

Comments
 (0)