Skip to content

Commit b3be665

Browse files
Chillee authored and simonlindholm committed
Added FFTMod (#67)
1 parent 5115bea commit b3be665

File tree

6 files changed

+127
-56
lines changed

6 files changed

+127
-56
lines changed

content/numerical/FastFourierTransform.h

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,32 @@
33
* Date: 2019-01-09
44
* License: CC0
55
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf (do read, it's excellent)
6-
Papers about accuracy: http://www.daemonology.net/papers/fft.pdf, http://www.cs.berkeley.edu/~fateman/papers/fftvsothers.pdf
7-
For integers rounding works if $(|a| + |b|)\max(a, b) < \mathtt{\sim} 10^9$, or in theory maybe $10^6$.
8-
* Description: fft(a, ...) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution:
6+
Accuracy bound from http://www.daemonology.net/papers/fft.pdf
7+
* Description: fft(a) computes $\hat f(k) = \sum_x a[x] \exp(2\pi i \cdot k x / N)$ for all $k$. Useful for convolution:
98
\texttt{conv(a, b) = c}, where $c[x] = \sum a[i]b[x-i]$.
109
For convolution of complex numbers or more than two vectors: FFT, multiply
1110
pointwise, divide by n, reverse(start+1, end), FFT back.
12-
For integers, consider using a number-theoretic transform instead, to avoid rounding issues.
13-
* Time: O(N \log N) with $N = |A|+|B|-1$ ($\tilde 1s$ for $N=2^{22}$)
11+
Rounding is safe if $(\sum a_i^2 + \sum b_i^2)\log_2{N} < 9\cdot10^{14}$
12+
(in practice $10^{16}$; higher for random inputs).
13+
Otherwise, use long doubles/NTT/FFTMod.
14+
* Time: O(N \log N) with $N = |A|+|B|$ (${\sim}1s$ for $N=2^{22}$)
1415
* Status: somewhat tested
1516
*/
1617
#pragma once
1718

1819
typedef complex<double> C;
1920
typedef vector<double> vd;
20-
21-
void fft(vector<C> &a, vector<C> &rt, vi& rev, int n) {
21+
void fft(vector<C>& a) {
22+
int n = sz(a), L = 31 - __builtin_clz(n);
23+
static vector<complex<long double>> R(2, 1);
24+
static vector<C> rt(2, 1); // (^ 10% faster if double)
25+
for (static int k = 2; k < n; k *= 2) {
26+
R.resize(n); rt.resize(n);
27+
auto x = polar(1.0L, M_PIl / k); // M_PI, lower-case L
28+
rep(i,k,2*k) rt[i] = R[i] = i&1 ? R[i/2] * x : R[i/2];
29+
}
30+
vi rev(n);
31+
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
2232
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
2333
for (int k = 1; k < n; k *= 2)
2434
for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
@@ -27,25 +37,19 @@ void fft(vector<C> &a, vector<C> &rt, vi& rev, int n) {
2737
C z(x[0]*y[0] - x[1]*y[1], x[0]*y[1] + x[1]*y[0]); /// exclude-line
2838
a[i + j + k] = a[i + j] - z;
2939
a[i + j] += z;
30-
}
40+
}
3141
}
32-
3342
vd conv(const vd& a, const vd& b) {
3443
if (a.empty() || b.empty()) return {};
3544
vd res(sz(a) + sz(b) - 1);
3645
int L = 32 - __builtin_clz(sz(res)), n = 1 << L;
37-
vector<C> in(n), out(n), rt(n, 1); vi rev(n);
38-
rep(i,0,n) rev[i] = (rev[i/2] | (i&1) << L) / 2;
39-
for (int k = 2; k < n; k *= 2) {
40-
C z[] = {1, polar(1.0, M_PI / k)};
41-
rep(i,k,2*k) rt[i] = rt[i/2] * z[i&1];
42-
}
46+
vector<C> in(n), out(n);
4347
copy(all(a), begin(in));
4448
rep(i,0,sz(b)) in[i].imag(b[i]);
45-
fft(in, rt, rev, n);
49+
fft(in);
4650
trav(x, in) x *= x;
4751
rep(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]);
48-
fft(out, rt, rev, n);
49-
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4*n);
52+
fft(out);
53+
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n);
5054
return res;
5155
}
content/numerical/FastFourierTransformMod.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/**
2+
* Author: chilli
3+
* Date: 2019-04-25
4+
* License: CC0
5+
* Source: http://neerc.ifmo.ru/trains/toulouse/2017/fft2.pdf
6+
* Description: Higher precision FFT, can be used for convolutions modulo arbitrary integers
7+
* as long as $N\log_2N\cdot \text{mod} < 8.6 \cdot 10^{14}$ (in practice $10^{16}$ or higher).
8+
* Inputs must be in $[0, \text{mod})$.
9+
* Time: O(N \log N), where $N = |A|+|B|$ (twice as slow as NTT or FFT)
10+
* Status: somewhat tested
11+
*/
12+
#pragma once
13+
14+
#include "FastFourierTransform.h"
15+
16+
typedef vector<ll> vl;
17+
/**
 * Convolution of a and b with every output coefficient reduced modulo M.
 *
 * Each input value v is written as v = (v / cut) * cut + (v % cut) with
 * cut = int(sqrt(M)); the quotient goes into the real part and the
 * remainder into the imaginary part of one complex sample, so the
 * floating-point fft only ever multiplies values of roughly sqrt(M) size.
 *
 * Returns an empty vector when either input is empty; otherwise the result
 * holds sz(a) + sz(b) - 1 entries.
 */
template<int M> vl convMod(const vl &a, const vl &b) {
	if (a.empty() || b.empty()) return {};
	vl res(sz(a) + sz(b) - 1);
	const int lg = 32 - __builtin_clz(sz(res));
	const int len = 1 << lg;
	const int cut = int(sqrt(M));
	vector<C> fa(len), fb(len), lowOut(len), highOut(len);

	// Pack (quotient, remainder) with respect to cut into (real, imag).
	for (int i = 0; i < sz(a); ++i) {
		const int v = (int)a[i];
		fa[i] = C(v / cut, v % cut);
	}
	for (int i = 0; i < sz(b); ++i) {
		const int v = (int)b[i];
		fb[i] = C(v / cut, v % cut);
	}

	fft(fa);
	fft(fb);

	const double scale = 2.0 * len;
	for (int i = 0; i < len; ++i) {
		const int j = -i & (len - 1); // (len - i) mod len
		highOut[j] = (fa[i] + conj(fa[j])) * fb[i] / scale;
		lowOut[j] = (fa[i] - conj(fa[j])) * fb[i] / scale / 1i;
	}

	fft(highOut);
	fft(lowOut);

	for (int i = 0; i < sz(res); ++i) {
		// Round each component to the nearest integer before recombining.
		const ll hi = ll(real(highOut[i]) + .5);
		const ll lo = ll(imag(lowOut[i]) + .5);
		const ll mid = ll(imag(highOut[i]) + .5) + ll(real(lowOut[i]) + .5);
		// Horner evaluation of hi * cut^2 + mid * cut + lo, modulo M.
		res[i] = ((hi % M * cut + mid) % M * cut + lo) % M;
	}
	return res;
}

content/numerical/NumberTheoreticTransform.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
* Source: based on KACTL's FFT
66
* Description: Can be used for convolutions modulo specific nice primes
77
* of the form $2^a b+1$, where the convolution result has size at most $2^a$.
8-
* For other primes/integers, use three different primes and combine with CRT.
98
* Inputs must be in [0, mod).
109
* Time: O(N \log N)
1110
* Status: fuzz-tested

content/numerical/chapter.tex

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ \chapter{Numerical}
1919
\kactlimport{Tridiagonal.h}
2020
\section{Fourier transforms}
2121
\kactlimport{FastFourierTransform.h}
22+
\kactlimport{FastFourierTransformMod.h}
2223
\kactlimport{NumberTheoreticTransform.h}
2324
\kactlimport{FastSubsetTransform.h}

fuzz-tests/numerical/FastFourierTransform.cpp

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,50 +11,33 @@ typedef long long ll;
1111
typedef pair<int, int> pii;
1212
typedef vector<int> vi;
1313

14-
typedef valarray<complex<double> > carray;
15-
void fft(carray& x, carray& roots) {
16-
int N = sz(x);
17-
if (N <= 1) return;
18-
carray even = x[slice(0, N/2, 2)];
19-
carray odd = x[slice(1, N/2, 2)];
20-
carray rs = roots[slice(0, N/2, 2)];
21-
fft(even, rs);
22-
fft(odd, rs);
23-
rep(k,0,N/2) {
24-
auto t = roots[k] * odd[k];
25-
x[k ] = even[k] + t;
26-
x[k+N/2] = even[k] - t;
27-
}
28-
}
29-
30-
typedef vector<double> vd;
31-
vd conv(const vd& a, const vd& b) {
32-
int s = sz(a) + sz(b) - 1, L = 32-__builtin_clz(s), n = 1<<L;
33-
if (s <= 0) return {};
34-
carray av(n), bv(n), roots(n);
35-
rep(i,0,n) roots[i] = polar(1.0, -2 * M_PI * i / n);
36-
copy(all(a), begin(av)); fft(av, roots);
37-
copy(all(b), begin(bv)); fft(bv, roots);
38-
roots = roots.apply(conj);
39-
carray cv = av * bv; fft(cv, roots);
40-
vd c(s); rep(i,0,s) c[i] = cv[i].real() / n;
41-
return c;
42-
}
14+
#include "../../content/numerical/FastFourierTransform.h"
4315

4416
const double eps = 1e-8;
4517
int main() {
4618
int n = 8;
47-
carray a(n), av(n), roots(n);
48-
rep(i,0,n) a[i] = rand() % 10 - 5;
49-
rep(i,0,n) roots[i] = polar(1.0, -2 * M_PI * i / n);
50-
av = a;
51-
fft(av, roots);
19+
vector<C> a(n);
20+
rep(i,0,n) a[i] = C(rand() % 10 - 5, rand() % 10 - 5);
21+
auto aorig = a;
22+
fft(a);
5223
rep(k,0,n) {
53-
complex<double> sum{};
24+
C sum{};
5425
rep(x,0,n) {
55-
sum += a[x] * polar(1.0, -2 * M_PI * k * x / n);
26+
sum += aorig[x] * polar(1.0, 2 * M_PI * k * x / n);
27+
}
28+
assert(norm(sum - a[k]) < 1e-6);
29+
}
30+
31+
vd A(4), B(6);
32+
trav(x, A) x = rand() / (RAND_MAX + 1.0) * 10 - 5;
33+
trav(x, B) x = rand() / (RAND_MAX + 1.0) * 10 - 5;
34+
vd C = conv(A, B);
35+
rep(i,0,sz(A) + sz(B) - 1) {
36+
double sum = 0;
37+
rep(j,0,sz(A)) if (i - j >= 0 && i - j < sz(B)) {
38+
sum += A[j] * B[i - j];
5639
}
57-
assert(abs(sum-av[k]) < eps);
40+
assert(abs(sum - C[i]) < eps);
5841
}
5942
cout<<"Tests passed!"<<endl;
6043
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include <bits/stdc++.h>
2+
using namespace std;
3+
4+
#define rep(i, a, b) for(int i = a; i < int(b); ++i)
5+
#define trav(a, v) for(auto& a : v)
6+
#define all(x) x.begin(), x.end()
7+
#define sz(x) (int)(x).size()
8+
9+
typedef long long ll;
10+
typedef pair<int, int> pii;
11+
typedef vector<int> vi;
12+
13+
const ll mod = 1000000007;
14+
15+
#include "../../content/numerical/FastFourierTransformMod.h"
16+
17+
/**
 * Quadratic-time reference convolution used to check convMod:
 * c[i + j] accumulates a[i] * b[j], reduced modulo mod after every step.
 * Returns an empty vector if either input is empty.
 */
vl simpleConv(vl a, vl b) {
	vl c;
	if (a.empty() || b.empty()) return c;
	c.assign(a.size() + b.size() - 1, 0);
	for (size_t i = 0; i < a.size(); ++i) {
		for (size_t j = 0; j < b.size(); ++j) {
			ll &slot = c[i + j];
			slot = (slot + (ll)a[i] * b[j]) % mod;
		}
	}
	// C++ '%' keeps the sign of the dividend; lift negatives into [0, mod).
	for (ll &x : c) {
		if (x < 0) x += mod;
	}
	return c;
}
26+
27+
/**
 * Small deterministic pseudo-random generator for the fuzz test.
 * The state starts at zero and persists across calls; each call applies
 * multiply, add, xor to it (unsigned arithmetic, so it wraps) and returns
 * the state shifted right by one bit, which is therefore non-negative.
 */
int ra() {
	static unsigned state;
	state = (state * 123671231 + 1238713) ^ 1237618;
	return (state >> 1);
}
34+
35+
int main() {
36+
vl a, b;
37+
rep(it,0,6000) {
38+
a.resize(ra() % 100);
39+
b.resize(ra() % 100);
40+
trav(x, a) x = ra() % mod;
41+
trav(x, b) x = ra() % mod;
42+
auto v1 = simpleConv(a, b);
43+
auto v2 = convMod<mod>(a, b);
44+
assert(v1 == v2);
45+
}
46+
cout<<"Tests passed!"<<endl;
47+
}

0 commit comments

Comments
 (0)