Skip to content

Commit a66369c

Browse files
committed
fix modswitch error criteria and first FNT code
1 parent cef10bf commit a66369c

File tree

4 files changed

+186
-2
lines changed

4 files changed

+186
-2
lines changed

include/fnt.hpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#pragma once
2+
#include <cstdint>
3+
#include <array>
4+
#include <span>
5+
6+
namespace FNTpp{
7+
constexpr unsigned int Kbit = 5;
8+
constexpr unsigned int K = 1 << Kbit;
9+
constexpr int64_t P = (1ULL << K) + 1;
10+
constexpr int64_t wordmask = (1ULL << K) - 1;
11+
12+
template <uint8_t bit>
13+
uint32_t BitReverse(uint32_t in)
14+
{
15+
if constexpr (bit > 1) {
16+
const uint32_t center = in & ((bit & 1) << (bit / 2));
17+
return (BitReverse<bit / 2>(in & ((1U << (bit / 2)) - 1))
18+
<< (bit + 1) / 2) |
19+
center | BitReverse<bit / 2>(in >> ((bit + 1) / 2));
20+
}
21+
else {
22+
return in;
23+
}
24+
}
25+
26+
static inline int64_t ModLshift(int64_t a, uint8_t b)
27+
{
28+
// If b >= 32, multiply by 2^32 ≡ -1 (mod P).
29+
// => a = P - a (unless a == 0), then reduce b by 32.
30+
if (b >= 32) {
31+
if (a != 0) {
32+
a = P - a;
33+
}
34+
b -= 32; // now b < 32
35+
}
36+
37+
// Shift by b < 32 in 64-bit arithmetic (safe from overflow).
38+
int64_t r = a << b;
39+
40+
// Now reduce a modulo P:
41+
// hi = upper 32 bits
42+
// lo = lower 32 bits
43+
// Since (hi << 32) ≡ -hi (mod P),
44+
// we can do (lo + hi) mod P and then subtract P if needed.
45+
const int64_t hi = r >> K;
46+
const int64_t lo = r & wordmask;
47+
r = -hi + lo;
48+
49+
// Subtract P once or twice if needed to ensure a < P
50+
if (r < 0) r += P;
51+
if (r >= P) r -= P;
52+
return r;
53+
}
54+
55+
template <uint8_t Nbit>
56+
inline void MulInvN(std::array<int64_t, 1u<<Nbit>& a){
57+
for(int i = 0; i < (1u<<Nbit); i++) a[i] = ModLshift(a[i], 2*K-Nbit);
58+
}
59+
60+
template <unsigned int Nbit>
61+
void FNT(const std::span<int64_t, 1u << Nbit> res)
62+
{
63+
if constexpr (Nbit == 1){
64+
const int64_t temp = res[0];
65+
res[0] += res[1];
66+
if(res[0] >= P) res[0] -= P;
67+
res[1] = temp - res[1];
68+
if(res[1] < 0) res[1] += P;
69+
}else{
70+
constexpr unsigned int N = 1u << Nbit;
71+
constexpr unsigned int stride = 1u << (Kbit+1 - Nbit);
72+
for(unsigned int i = 0; i < N/2; i++){
73+
const int64_t temp = res[i]+res[i+N/2];
74+
res[i+N/2] = res[i]-res[i+N/2];
75+
if(res[i+N/2] < 0) res[i+N/2] += P;
76+
if(i!=0) res[i+N/2] = ModLshift(res[i+N/2],i*stride);
77+
res[i] = temp >= P ? temp - P : temp;
78+
}
79+
FNT<Nbit-1>(res.template subspan<0,N/2>());
80+
FNT<Nbit-1>(res.template subspan<N/2,N/2>());
81+
}
82+
}
83+
84+
template <unsigned int Nbit>
85+
void IFNT(const std::span<int64_t, 1u << Nbit> res)
86+
{
87+
if constexpr (Nbit == 1){
88+
const int64_t temp = res[0];
89+
res[0] += res[1];
90+
if(res[0] >= P) res[0] -= P;
91+
res[1] = temp - res[1];
92+
if(res[1] < 0) res[1] += P;
93+
}else{
94+
constexpr unsigned int N = 1u << Nbit;
95+
IFNT<Nbit-1>(res.template subspan<0,N/2>());
96+
IFNT<Nbit-1>(res.template subspan<N/2,N/2>());
97+
constexpr unsigned int stride = 1u << (Kbit+1 - Nbit);
98+
for(unsigned int i = 0; i < N/2; i++){
99+
if(i!=0) res[i+N/2] = ModLshift(res[i+N/2],(N-i)*stride);
100+
const int64_t temp = res[i]+res[i+N/2];
101+
res[i+N/2] = res[i]-res[i+N/2]; //Part of twiddle factor
102+
if(res[i+N/2] < 0) res[i+N/2] += P;
103+
res[i] = temp >= P ? temp - P : temp;
104+
}
105+
}
106+
}
107+
108+
109+
template <unsigned int Nbit>
110+
void TwistFNT(std::array<int64_t, 1u << (Nbit+1)> &res, const std::array<int64_t, 1u << Nbit> &a)
111+
{
112+
constexpr unsigned int formersizebit = (Nbit + 1)/2;
113+
static_assert(formersizebit <= Kbit, "sizebit must be less than or equal to Kbit");
114+
constexpr unsigned int formersize = 1u << formersizebit;
115+
constexpr unsigned int latersizebit = (Nbit + 1) - formersizebit;
116+
constexpr unsigned int latersize = 1u << latersizebit;
117+
constexpr unsigned int formerrbit = (Kbit+1) - formersizebit;
118+
constexpr unsigned int laterrbit = (Kbit+1) - latersizebit;
119+
//Former
120+
for(unsigned int i = 0; i < latersize/2; i++){
121+
std::array<int64_t, formersize> temp;
122+
for(unsigned int j = 0; j < formersize; j++)
123+
temp[j] = ModLshift(a[j*(latersize/2) + i],(j*(latersize/2) + i)<<(formerrbit-1));
124+
FNT<formersizebit>(std::span{temp});
125+
for(unsigned int j = 0; j < formersize; j++)
126+
res[j*latersize + i] = temp[j];
127+
}
128+
//Later
129+
for(unsigned int i = 0; i < formersize; i++)
130+
FNT<latersizebit>(std::span{res}.subspan(i,latersize));
131+
}
132+
}

include/raintt.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ constexpr T ipow(T num, unsigned int pow)
1313
: pow == 0 ? 1
1414
: num * ipow(num, pow - 1);
1515
}
16+
#ifdef USE_COMPRESS
17+
constexpr uint min_wordbits = 27;
18+
#else
19+
constexpr uint min_wordbits = 31;
20+
#endif
1621

1722
#ifdef __clang__
1823
// Currently _BigInt is only implemented in clang

test/fnt.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include<fnt.hpp>
2+
#include<random>
3+
#include<array>
4+
#include<iostream>
5+
#include<cassert>
6+
7+
int main(){
8+
constexpr uint32_t num_test = 1000;
9+
constexpr unsigned int Nbit = 6;
10+
constexpr unsigned int N = 1u << Nbit;
11+
12+
std::random_device seed_gen;
13+
std::default_random_engine engine(seed_gen());
14+
std::uniform_int_distribution<int64_t> Pdist(0, FNTpp::P);
15+
16+
std::cout<< "Start ModLShift Test"<< std::endl;
17+
for(int test = 0; test < num_test; test++){
18+
const int64_t a = Pdist(engine);
19+
const uint shift = std::uniform_int_distribution<uint>(0, 63)(engine);
20+
const int64_t res = FNTpp::ModLshift(a, shift);
21+
if(res != (static_cast<__int128_t>(a) << shift) % FNTpp::P)
22+
std::cout << "a: " << a << " shift: " << shift << " res: " << res << " expected: " << static_cast<int64_t>((static_cast<__int128_t>(a) << shift) % FNTpp::P) << std::endl;
23+
assert(res == (static_cast<__int128_t>(a) << shift) % FNTpp::P);
24+
}
25+
std::cout<< "Passed ModLShift"<< std::endl;
26+
27+
std::cout << "invN Test" << std::endl;
28+
assert(1 == FNTpp::ModLshift(N, 2*FNTpp::K-Nbit));
29+
std::cout << "Passed invN" << std::endl;
30+
31+
std::cout << "Start univariable FNT only test." << std::endl;
32+
for(int test = 0; test < num_test; test++){
33+
std::array<int64_t, N> a;
34+
for(int i = 0; i < N; i++) a[i] = Pdist(engine);
35+
std::array<int64_t, N> res;
36+
res = a;
37+
FNTpp::FNT<Nbit>(res);
38+
FNTpp::IFNT<Nbit>(res);
39+
FNTpp::MulInvN<Nbit>(res);
40+
for(int i = 0; i < N; i++)
41+
if(a[i] != res[i]) std::cout << "i: "<< i << " a: " << a[i] << " res: " << res[i] << std::endl;
42+
for(int i = 0; i < N; i++) assert(a[i] == res[i]);
43+
}
44+
std::cout << "Univariable FNT only test Passed" << std::endl;
45+
46+
return 0;
47+
}

test/raintt.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ int main()
194194
(1ULL << (raintt::wordbits - 1 - 1))) >>
195195
(raintt::wordbits - 1);
196196
assert(std::abs(static_cast<int>(a - c)) <=
197-
(1 << (32 - raintt::wordbits + 1)));
197+
((1U << raintt::min_wordbits)/raintt::P)+1);
198198
}
199199
std::cout << "Modswitch Passed" << std::endl;
200200
for (int test = 0; test < num_test; test++) {
@@ -215,7 +215,7 @@ int main()
215215
// 4)std::cout<<res[i]<<":"<<a[i]<<std::endl;
216216
for (int i = 0; i < TFHEpp::lvl1param::n; i++)
217217
assert(std::abs(static_cast<int>(res[i] - a[i])) <=
218-
(1 << (32 - raintt::wordbits + 1)));
218+
((1U << raintt::min_wordbits)/raintt::P)+1);
219219
}
220220
std::cout << "NTT with modswitch Passed" << std::endl;
221221

0 commit comments

Comments
 (0)