1
+ #pragma once
2
+ #include < cstdint>
3
+ #include < array>
4
+ #include < span>
5
+
6
+ namespace FNTpp {
7
+ constexpr unsigned int Kbit = 5 ;
8
+ constexpr unsigned int K = 1 << Kbit;
9
+ constexpr int64_t P = (1ULL << K) + 1 ;
10
+ constexpr int64_t wordmask = (1ULL << K) - 1 ;
11
+
12
+ template <uint8_t bit>
13
+ uint32_t BitReverse (uint32_t in)
14
+ {
15
+ if constexpr (bit > 1 ) {
16
+ const uint32_t center = in & ((bit & 1 ) << (bit / 2 ));
17
+ return (BitReverse<bit / 2 >(in & ((1U << (bit / 2 )) - 1 ))
18
+ << (bit + 1 ) / 2 ) |
19
+ center | BitReverse<bit / 2 >(in >> ((bit + 1 ) / 2 ));
20
+ }
21
+ else {
22
+ return in;
23
+ }
24
+ }
25
+
26
+ static inline int64_t ModLshift (int64_t a, uint8_t b)
27
+ {
28
+ // If b >= 32, multiply by 2^32 ≡ -1 (mod P).
29
+ // => a = P - a (unless a == 0), then reduce b by 32.
30
+ if (b >= 32 ) {
31
+ if (a != 0 ) {
32
+ a = P - a;
33
+ }
34
+ b -= 32 ; // now b < 32
35
+ }
36
+
37
+ // Shift by b < 32 in 64-bit arithmetic (safe from overflow).
38
+ int64_t r = a << b;
39
+
40
+ // Now reduce a modulo P:
41
+ // hi = upper 32 bits
42
+ // lo = lower 32 bits
43
+ // Since (hi << 32) ≡ -hi (mod P),
44
+ // we can do (lo + hi) mod P and then subtract P if needed.
45
+ const int64_t hi = r >> K;
46
+ const int64_t lo = r & wordmask;
47
+ r = -hi + lo;
48
+
49
+ // Subtract P once or twice if needed to ensure a < P
50
+ if (r < 0 ) r += P;
51
+ if (r >= P) r -= P;
52
+ return r;
53
+ }
54
+
55
+ template <uint8_t Nbit>
56
+ inline void MulInvN (std::array<int64_t , 1u <<Nbit>& a){
57
+ for (int i = 0 ; i < (1u <<Nbit); i++) a[i] = ModLshift (a[i], 2 *K-Nbit);
58
+ }
59
+
60
+ template <unsigned int Nbit>
61
+ void FNT (const std::span<int64_t , 1u << Nbit> res)
62
+ {
63
+ if constexpr (Nbit == 1 ){
64
+ const int64_t temp = res[0 ];
65
+ res[0 ] += res[1 ];
66
+ if (res[0 ] >= P) res[0 ] -= P;
67
+ res[1 ] = temp - res[1 ];
68
+ if (res[1 ] < 0 ) res[1 ] += P;
69
+ }else {
70
+ constexpr unsigned int N = 1u << Nbit;
71
+ constexpr unsigned int stride = 1u << (Kbit+1 - Nbit);
72
+ for (unsigned int i = 0 ; i < N/2 ; i++){
73
+ const int64_t temp = res[i]+res[i+N/2 ];
74
+ res[i+N/2 ] = res[i]-res[i+N/2 ];
75
+ if (res[i+N/2 ] < 0 ) res[i+N/2 ] += P;
76
+ if (i!=0 ) res[i+N/2 ] = ModLshift (res[i+N/2 ],i*stride);
77
+ res[i] = temp >= P ? temp - P : temp;
78
+ }
79
+ FNT<Nbit-1 >(res.template subspan <0 ,N/2 >());
80
+ FNT<Nbit-1 >(res.template subspan <N/2 ,N/2 >());
81
+ }
82
+ }
83
+
84
+ template <unsigned int Nbit>
85
+ void IFNT (const std::span<int64_t , 1u << Nbit> res)
86
+ {
87
+ if constexpr (Nbit == 1 ){
88
+ const int64_t temp = res[0 ];
89
+ res[0 ] += res[1 ];
90
+ if (res[0 ] >= P) res[0 ] -= P;
91
+ res[1 ] = temp - res[1 ];
92
+ if (res[1 ] < 0 ) res[1 ] += P;
93
+ }else {
94
+ constexpr unsigned int N = 1u << Nbit;
95
+ IFNT<Nbit-1 >(res.template subspan <0 ,N/2 >());
96
+ IFNT<Nbit-1 >(res.template subspan <N/2 ,N/2 >());
97
+ constexpr unsigned int stride = 1u << (Kbit+1 - Nbit);
98
+ for (unsigned int i = 0 ; i < N/2 ; i++){
99
+ if (i!=0 ) res[i+N/2 ] = ModLshift (res[i+N/2 ],(N-i)*stride);
100
+ const int64_t temp = res[i]+res[i+N/2 ];
101
+ res[i+N/2 ] = res[i]-res[i+N/2 ]; // Part of twiddle factor
102
+ if (res[i+N/2 ] < 0 ) res[i+N/2 ] += P;
103
+ res[i] = temp >= P ? temp - P : temp;
104
+ }
105
+ }
106
+ }
107
+
108
+
109
+ template <unsigned int Nbit>
110
+ void TwistFNT (std::array<int64_t , 1u << (Nbit+1 )> &res, const std::array<int64_t, 1u << Nbit> &a)
111
+ {
112
+ constexpr unsigned int formersizebit = (Nbit + 1 )/2 ;
113
+ static_assert (formersizebit <= Kbit, " sizebit must be less than or equal to Kbit" );
114
+ constexpr unsigned int formersize = 1u << formersizebit;
115
+ constexpr unsigned int latersizebit = (Nbit + 1 ) - formersizebit;
116
+ constexpr unsigned int latersize = 1u << latersizebit;
117
+ constexpr unsigned int formerrbit = (Kbit+1 ) - formersizebit;
118
+ constexpr unsigned int laterrbit = (Kbit+1 ) - latersizebit;
119
+ // Former
120
+ for (unsigned int i = 0 ; i < latersize/2 ; i++){
121
+ std::array<int64_t , formersize> temp;
122
+ for (unsigned int j = 0 ; j < formersize; j++)
123
+ temp[j] = ModLshift (a[j*(latersize/2 ) + i],(j*(latersize/2 ) + i)<<(formerrbit-1 ));
124
+ FNT<formersizebit>(std::span{temp});
125
+ for (unsigned int j = 0 ; j < formersize; j++)
126
+ res[j*latersize + i] = temp[j];
127
+ }
128
+ // Later
129
+ for (unsigned int i = 0 ; i < formersize; i++)
130
+ FNT<latersizebit>(std::span{res}.subspan (i,latersize));
131
+ }
132
+ }
0 commit comments