Skip to content

Commit fc81381

Browse files
committed
AESEnc working
1 parent 4c00d81 commit fc81381

File tree

5 files changed

+245
-16
lines changed

5 files changed

+245
-16
lines changed

include/aes.hpp

Lines changed: 168 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ constexpr uint Nk = 4; // Number of 32-bit words in the key
4040
constexpr uint Nb = 4; // Number of columns (32-bit words) comprising the state
4141
constexpr uint Nr = 10; // Number of rounds, which is a function of Nk and Nb
4242

43-
void KeyExpansion(std::array<uint8_t, 4*Nb*(Nr+1)> w, const std::array<uint8_t,16> key) {
43+
void KeyExpansion(std::array<uint8_t, 4*Nb*(Nr+1)>& w, const std::array<uint8_t,16>& key) {
4444
std::array<uint8_t,4> temp;
4545

4646
unsigned int i = 0;
@@ -126,14 +126,26 @@ void AESSboxROM(std::array<TLWE<typename brP::targetP>,8> &res,
126126
}
127127

128128
template <class iksP, class brP>
129-
void SubBytes(std::array<TLWE<typename brP::targetP>, 128> &res,
130-
const std::array<TLWE<typename iksP::domainP>, 128> &tlwe,
129+
void SubBytes(std::array<TLWE<typename brP::targetP>, 128>& state,
131130
const EvalKey &ek)
132131
{
133132
for(int i = 0; i < 16; i++)
134-
AESSboxROM<iksP, brP>(std::span(res).subspan(i*8).template first<8>(), std::span(tlwe).subspan(i*8).template first<8>(), ek);
133+
AESSboxROM<iksP, brP>(std::span(state).subspan(i*8).template first<8>(), std::span(state).subspan(i*8).template first<8>(), ek);
135134
}
136135

136+
template <class iksP, class brP, class cbiksP, class cbbrP>
137+
void SubBytes(std::array<TLWE<typename brP::targetP>, 128>& state,
138+
const EvalKey &ek)
139+
{
140+
for(int i = 0; i < 16; i++){
141+
std::array<TLWE<typename cbbrP::targetP>, 8> temp;
142+
AESSboxROM<iksP, cbbrP>(std::span(temp), std::span(state).subspan(i*8).template first<8>(), ek);
143+
for(int j = 0; j < 8; j++)
144+
GateBootstrapping<cbiksP, brP, 1ULL << (std::numeric_limits<typename brP::targetP::T>::digits - 2)>(state[i*8+j], temp[j], ek);
145+
}
146+
}
147+
148+
137149
template <class P>
138150
inline Polynomial<P> AESInvSboxPoly(const uint8_t upperindex)
139151
{
@@ -357,6 +369,121 @@ void MixColumn(std::array<TLWE<P>, 32>& y_out, const std::array<TLWE<P>, 32>& x)
357369
TLWEAdd<P>(y_out[28], x[20], t[59]);
358370
}
359371

372+
// https://eprint.iacr.org/2024/1076
373+
// Computes the 32 output ciphertexts y[0..31] from the 32 input ciphertexts
// x[0..31] with a fixed straight-line program made only of TLWEAdd calls:
// 97 additions in total (65 temporaries r[0..64] plus the 32 outputs).
// Intended to produce the same result as MixColumn; test/mixcolumnDepth4.cpp
// asserts that the two agree after decryption.
//
// y : output column; every element y[0..31] is written exactly once.
// x : input column; only read.
//
// y and x must be distinct arrays: elements of x are still read after
// elements of y have been written (e.g. y[13] is produced before x[13] is
// consumed for y[14]).
//
// NOTE(review): the name suggests every output is at most four additions
// deep, per the referenced paper - confirm against the paper's listing.
// NOTE(review): presumably TLWEAdd acts as XOR under the bit encoding the
// callers use (they add a 2^(digits-2) offset to the constant term) - confirm.
template <class P>
374+
void MixColumnDepth4(std::array<TLWE<P>, 32>& y, const std::array<TLWE<P>, 32>& x){
375+
// r0 … r64
// Scratch ciphertexts; each r[k] is assigned once before it is first read.
376+
std::array<TLWE<P>, 65> r;
377+
378+
// --- first stage --------------------------------------------------------
// r[0..23]: sums of two inputs each.
379+
TLWEAdd<P>(r[ 0], x[23], x[31]);
380+
TLWEAdd<P>(r[ 1], x[21], x[29]);
381+
TLWEAdd<P>(r[ 2], x[17], x[25]);
382+
TLWEAdd<P>(r[ 3], x[16], x[24]);
383+
TLWEAdd<P>(r[ 4], x[15], x[23]);
384+
TLWEAdd<P>(r[ 5], x[14], x[22]);
385+
TLWEAdd<P>(r[ 6], x[12], x[20]);
386+
TLWEAdd<P>(r[ 7], x[12], x[13]);
387+
TLWEAdd<P>(r[ 8], x[11], x[20]);
388+
TLWEAdd<P>(r[ 9], x[10], x[25]);
389+
TLWEAdd<P>(r[10], x[10], x[18]);
390+
TLWEAdd<P>(r[11], x[ 9], x[18]);
391+
TLWEAdd<P>(r[12], x[ 7], x[31]);
392+
TLWEAdd<P>(r[13], x[ 7], x[15]);
393+
TLWEAdd<P>(r[14], x[ 6], x[31]);
394+
TLWEAdd<P>(r[15], x[ 6], x[30]);
395+
TLWEAdd<P>(r[16], x[ 5], x[13]);
396+
TLWEAdd<P>(r[17], x[ 5], x[ 6]);
397+
TLWEAdd<P>(r[18], x[ 4], x[28]);
398+
TLWEAdd<P>(r[19], x[ 3], x[27]);
399+
TLWEAdd<P>(r[20], x[ 3], x[11]);
400+
TLWEAdd<P>(r[21], x[ 2], x[26]);
401+
TLWEAdd<P>(r[22], x[ 1], x[ 9]);
402+
TLWEAdd<P>(r[23], x[ 0], x[ 8]);
403+
404+
// --- second stage -------------------------------------------------------
// r[24..54]: a first-stage value combined with an input or with another
// first-stage value (r[53] and r[54] build on r[24] from this same stage).
405+
TLWEAdd<P>(r[24], r[ 0], x[27]);
406+
TLWEAdd<P>(r[25], r[ 0], x[ 7]);
407+
TLWEAdd<P>(r[26], r[ 1], x[ 5]);
408+
TLWEAdd<P>(r[27], r[ 1], x[ 4]);
409+
TLWEAdd<P>(r[28], r[ 2], x[ 1]);
410+
TLWEAdd<P>(r[29], r[ 3], x[ 8]);
411+
TLWEAdd<P>(r[30], r[ 3], r[ 0]);
412+
TLWEAdd<P>(r[31], r[ 4], x[14]);
413+
TLWEAdd<P>(r[32], r[ 4], x[ 7]);
414+
TLWEAdd<P>(r[33], r[ 4], x[ 0]);
415+
TLWEAdd<P>(r[34], r[ 5], x[29]);
416+
TLWEAdd<P>(r[35], r[ 6], x[28]);
417+
TLWEAdd<P>(r[36], r[ 6], x[ 4]);
418+
TLWEAdd<P>(r[37], r[10], x[26]);
419+
TLWEAdd<P>(r[38], r[10], r[ 4]);
420+
TLWEAdd<P>(r[39], r[13], r[ 2]);
421+
TLWEAdd<P>(r[40], r[15], x[22]);
422+
TLWEAdd<P>(r[41], r[16], x[30]);
423+
TLWEAdd<P>(r[42], r[16], x[28]);
424+
TLWEAdd<P>(r[43], r[18], x[19]);
425+
TLWEAdd<P>(r[44], r[19], x[19]);
426+
TLWEAdd<P>(r[45], r[19], r[12]);
427+
TLWEAdd<P>(r[46], r[20], x[26]);
428+
TLWEAdd<P>(r[47], r[20], r[13]);
429+
TLWEAdd<P>(r[48], r[21], x[17]);
430+
TLWEAdd<P>(r[49], r[22], x[25]);
431+
TLWEAdd<P>(r[50], r[22], x[17]);
432+
TLWEAdd<P>(r[51], r[23], x[16]);
433+
TLWEAdd<P>(r[52], r[23], x[ 9]);
434+
TLWEAdd<P>(r[53], r[24], x[18]);
435+
TLWEAdd<P>(r[54], r[24], x[12]);
436+
437+
// --- outputs that depend only on r0 … r54 -------------------------------
438+
TLWEAdd<P>(y[15], r[25], r[ 5]);
439+
TLWEAdd<P>(y[13], r[26], r[ 6]);
440+
TLWEAdd<P>(y[ 5], r[27], r[ 7]);
441+
TLWEAdd<P>(y[ 0], r[29], r[13]);
442+
TLWEAdd<P>(y[ 7], r[31], r[14]);
443+
TLWEAdd<P>(y[31], r[32], r[15]);
444+
TLWEAdd<P>(y[ 8], r[33], r[ 3]);
445+
TLWEAdd<P>(y[30], r[34], r[17]);
446+
TLWEAdd<P>(y[ 2], r[37], r[22]);
447+
448+
// --- remaining intermediates -------------------------------------------
// r[55..64] are interleaved with the outputs that consume them; statement
// order matters because each r[k] must be written before it is read.
449+
TLWEAdd<P>(r[55], r[37], r[28]);
450+
TLWEAdd<P>(r[56], r[40], x[21]);
451+
TLWEAdd<P>(r[57], r[40], r[13]);
452+
TLWEAdd<P>(y[ 6], r[41], r[ 5]);
453+
TLWEAdd<P>(r[58], r[42], x[29]);
454+
TLWEAdd<P>(r[59], r[43], r[ 4]);
455+
TLWEAdd<P>(r[60], r[44], x[ 2]);
456+
TLWEAdd<P>(y[11], r[44], r[38]);
457+
TLWEAdd<P>(y[28], r[45], r[36]);
458+
TLWEAdd<P>(r[61], r[46], r[45]);
459+
TLWEAdd<P>(r[62], r[47], x[10]);
460+
TLWEAdd<P>(y[ 4], r[47], r[35]);
461+
TLWEAdd<P>(y[18], r[48], r[ 9]);
462+
TLWEAdd<P>(y[10], r[48], r[11]);
463+
TLWEAdd<P>(y[17], r[49], r[30]);
464+
TLWEAdd<P>(r[63], r[50], r[29]);
465+
TLWEAdd<P>(y[24], r[51], r[12]);
466+
TLWEAdd<P>(y[16], r[51], r[30]);
467+
TLWEAdd<P>(r[64], r[51], r[33]);
468+
TLWEAdd<P>(y[ 1], r[52], r[39]);
469+
TLWEAdd<P>(y[19], r[53], r[46]);
470+
TLWEAdd<P>(y[20], r[54], r[43]);
471+
TLWEAdd<P>(y[26], r[55], r[48]);
472+
TLWEAdd<P>(y[14], r[56], x[13]);
473+
TLWEAdd<P>(y[22], r[56], r[34]);
474+
TLWEAdd<P>(y[23], r[57], r[14]);
475+
TLWEAdd<P>(y[21], r[58], x[20]);
476+
TLWEAdd<P>(y[29], r[58], r[27]);
477+
TLWEAdd<P>(y[12], r[59], r[ 8]);
478+
TLWEAdd<P>(y[27], r[61], r[60]);
479+
TLWEAdd<P>(y[ 3], r[62], r[60]);
480+
481+
// y25 depends on y24 already produced
// (the only place where an output element is read back as an operand)
482+
TLWEAdd<P>(y[25], y[24], r[63]);
483+
484+
TLWEAdd<P>(y[ 9], r[64], r[28]);
485+
}
486+
360487
// https://eprint.iacr.org/2019/833
361488
template <class P>
362489
void MixColumns(std::array<TLWE<P>, 128> &state) {
@@ -374,6 +501,7 @@ void MixColumns(std::array<TLWE<P>, 128> &state) {
374501
// Apply the MixColumn transformation
375502
std::array<TLWE<P>, 32> y_out; // For final output bits y0...y31
376503
MixColumn<P>(y_out, x);
504+
// MixColumnDepth4<P>(y_out, x);
377505

378506
// Place the resulting 32-bit column (y_out) back into the state array
379507
for (int i = 0; i < 4; ++i)
@@ -466,17 +594,49 @@ void AESEnc(std::array<TLWE<typename brP::targetP>, 128> &cipher,
466594
TLWEAdd<typename iksP::domainP>(state[i*Nb*8+j*8+k], plain[j*4*8+i*8+k], expandedkey[0][j*4*8+i*8+k]);
467595
state[i*Nb*8+j*8+k][iksP::domainP::k * iksP::domainP::n] += 1ULL << (std::numeric_limits<typename iksP::domainP::T>::digits - 2);
468596
}
469-
// state[i*Nb+j] = plain[j*4+i];]
470-
// AddRoundKey<typename iksP::domainP>(state, expandedkey[0]);
471597

472598
// Rounds
473599
for (int round = 1; round < Nr; round++) {
474-
SubBytes<iksP, brP>(state, state, ek);
600+
SubBytes<iksP, brP>(state, ek);
601+
ShiftRows<typename brP::targetP>(state);
602+
MixColumns<typename brP::targetP>(state);
603+
AddRoundKey<typename brP::targetP>(state, expandedkey[round]);
604+
}
605+
SubBytes<iksP, brP>(state, ek);
606+
ShiftRows<typename brP::targetP>(state);
607+
AddRoundKey<typename brP::targetP>(state, expandedkey[Nr]);
608+
609+
// Copy state to ciphertext with transposition
610+
for (int i = 0; i < 4; i++)
611+
for (int j = 0; j < Nb; j++)
612+
for(int k = 0; k < 8; k++)
613+
cipher[j*4*8+i*8+k] = state[i*Nb*8+j*8+k];
614+
}
615+
616+
template <class iksP, class brP, class cbiksP, class cbbrP>
617+
void AESEnc(std::array<TLWE<typename brP::targetP>, 128> &cipher,
618+
const std::array<TLWE<typename iksP::domainP>, 128> &plain,
619+
const std::array<std::array<TLWE<typename brP::targetP>, 128>, Nr+1> &expandedkey,
620+
EvalKey &ek)
621+
{
622+
std::array<TLWE<typename iksP::domainP>, 128> state;
623+
// Copy plaintext to state with transposition
624+
// Initial AddRoundKey
625+
for (int i = 0; i < 4; i++)
626+
for (int j = 0; j < Nb; j++)
627+
for(int k = 0; k < 8; k++){
628+
TLWEAdd<typename iksP::domainP>(state[i*Nb*8+j*8+k], plain[j*4*8+i*8+k], expandedkey[0][j*4*8+i*8+k]);
629+
state[i*Nb*8+j*8+k][iksP::domainP::k * iksP::domainP::n] += 1ULL << (std::numeric_limits<typename iksP::domainP::T>::digits - 2);
630+
}
631+
632+
// Rounds
633+
for (int round = 1; round < Nr; round++) {
634+
SubBytes<iksP, brP, cbiksP, cbbrP>(state, ek);
475635
ShiftRows<typename brP::targetP>(state);
476636
MixColumns<typename brP::targetP>(state);
477637
AddRoundKey<typename brP::targetP>(state, expandedkey[round]);
478638
}
479-
SubBytes<iksP, brP>(state, state, ek);
639+
SubBytes<iksP, brP, cbiksP, cbbrP>(state, ek);
480640
ShiftRows<typename brP::targetP>(state);
481641
AddRoundKey<typename brP::targetP>(state, expandedkey[Nr]);
482642

include/gatebootstrapping.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ constexpr Polynomial<P> μpolygen()
249249
}
250250

251251
template <class bkP, typename bkP::targetP::T μ, class iksP>
252-
void GateBootstrapping(TLWE<typename bkP::domainP> &res,
252+
void GateBootstrapping(TLWE<typename iksP::targetP> &res,
253253
const TLWE<typename bkP::domainP> &tlwe,
254254
const EvalKey &ek)
255255
{
@@ -260,7 +260,7 @@ void GateBootstrapping(TLWE<typename bkP::domainP> &res,
260260
}
261261

262262
template <class iksP, class bkP, typename bkP::targetP::T μ>
263-
void GateBootstrapping(TLWE<typename iksP::domainP> &res,
263+
void GateBootstrapping(TLWE<typename bkP::targetP> &res,
264264
const TLWE<typename iksP::domainP> &tlwe,
265265
const EvalKey &ek)
266266
{

test/aesencWOexpansion.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77

88
int main()
99
{
10-
using brP = TFHEpp::lvl02param;
11-
using iksP = TFHEpp::lvl20param;
10+
using brP = TFHEpp::lvl01param;
11+
using iksP = TFHEpp::lvl10param;
12+
using cbiksP = TFHEpp::lvl20param;
13+
using cbbrP = TFHEpp::lvl02param;
14+
// using brP = TFHEpp::lvl02param;
15+
// using iksP = TFHEpp::lvl20param;
1216
std::random_device seed_gen;
1317
std::default_random_engine engine(seed_gen());
1418
std::uniform_int_distribution<uint32_t> binary(0, 1);
@@ -44,15 +48,18 @@ int main()
4448
num_test);
4549
TFHEpp::EvalKey ek;
4650
ek.emplacebkfft<brP>(*sk);
51+
ek.emplacebkfft<cbbrP>(*sk);
4752
ek.emplaceiksk<iksP>(*sk);
48-
ek.emplaceahk<typename brP::targetP>(*sk);
49-
ek.emplacecbsk<typename brP::targetP>(*sk);
53+
ek.emplaceiksk<cbiksP>(*sk);
54+
ek.emplaceahk<typename cbbrP::targetP>(*sk);
55+
ek.emplacecbsk<typename cbbrP::targetP>(*sk);
5056

5157
std::chrono::system_clock::time_point start, end;
5258
start = std::chrono::system_clock::now();
5359
for (int test = 0; test < num_test; test++) {
5460
std::cout << "test: " << test << std::endl;
55-
TFHEpp::AESEnc<iksP, brP>(cres[test], cin[test], cexpandedkey[test], ek);
61+
TFHEpp::AESEnc<iksP, brP, cbiksP, cbbrP>(cres[test], cin[test], cexpandedkey[test], ek);
62+
// TFHEpp::AESEnc<iksP, brP>(cres[test], cin[test], cexpandedkey[test], ek);
5663
}
5764

5865
end = std::chrono::system_clock::now();
@@ -70,7 +77,7 @@ int main()
7077
cres[i][j*8+k], *sk))
7178
pres |= (1 << k);
7279
}
73-
std::cout << "j: " << j << ",c: " << (int)c[j] << " pres: " << (int)pres << std::endl;
80+
// std::cout << "j: " << j << ",c: " << (int)c[j] << " pres: " << (int)pres << std::endl;
7481
assert(pres == c[j]);
7582
}
7683
}

test/mixcolumnDepth4.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include <cassert>
2+
#include <chrono>
3+
#include <iostream>
4+
#include <memory>
5+
#include <random>
6+
#include <tfhe++.hpp>
7+
8+
9+
int main()
10+
{
11+
using P = TFHEpp::lvl2param;
12+
std::random_device seed_gen;
13+
std::default_random_engine engine(seed_gen());
14+
std::uniform_int_distribution<uint32_t> binary(0, 1);
15+
16+
std::unique_ptr<TFHEpp::SecretKey> sk(new TFHEpp::SecretKey());
17+
constexpr uint num_test = 1000;
18+
std::vector<std::array<TFHEpp::TLWE<P>,32>> cin(num_test);
19+
std::vector<std::array<uint8_t, 32>> pin(num_test);
20+
21+
for (int i = 0; i < num_test; i++) {
22+
for (int j = 0; j < 32; j++){
23+
pin[i][j] = binary(engine);
24+
cin[i][j] = TFHEpp::tlweSymEncrypt<P>(pin[i][j]?1ULL << (std::numeric_limits<typename P::T>::digits - 2):-(1ULL << (std::numeric_limits<typename P::T>::digits - 2)), *sk);
25+
cin[i][j][P::k * P::n] += 1ULL << (std::numeric_limits<typename P::T>::digits - 2);
26+
}
27+
}
28+
29+
std::vector<std::array<TFHEpp::TLWE<P>,32>> cres(num_test);
30+
std::vector<std::array<TFHEpp::TLWE<P>,32>> cref(num_test);
31+
32+
for (int test = 0; test < num_test; test++) {
33+
// std::cout << "test: " << test << std::endl;
34+
TFHEpp::MixColumn<P>(cref[test],cin[test]);
35+
TFHEpp::MixColumnDepth4<P>(cres[test],cin[test]);
36+
for (int j = 0; j < 32; j++){
37+
cref[test][j][P::k * P::n] -= 1ULL << (std::numeric_limits<typename P::T>::digits - 2);
38+
cres[test][j][P::k * P::n] -= 1ULL << (std::numeric_limits<typename P::T>::digits - 2);
39+
}
40+
}
41+
42+
for (int i = 0; i < num_test; i++) {
43+
uint32_t pincat = 0;
44+
uint32_t pres = 0;
45+
uint32_t pref = 0;
46+
for (int j = 0; j < 32; j++){
47+
pres |= (static_cast<uint32_t>(TFHEpp::tlweSymDecrypt<P>(cres[i][j], *sk))) << j;
48+
pref |= (static_cast<uint32_t>(TFHEpp::tlweSymDecrypt<P>(cref[i][j], *sk))) << j;
49+
pincat |= static_cast<uint32_t>(pin[i][j]) << j;
50+
}
51+
// std::bitset<32> bpin(pincat);
52+
// std::bitset<32> bpres(pres);
53+
// std::bitset<32> bpref(pref);
54+
// std::cout << "i: " << i << std::endl;
55+
// std::cout << "pin: " << bpin << std::endl;
56+
// std::cout << "pres: " << bpres << std::endl;
57+
// std::cout << "pref: " << bpref << std::endl;
58+
assert(pres == pref);
59+
}
60+
std::cout << "PASS" << std::endl;
61+
}

test/mixcolumns.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ int main()
7878
uint8_t byte = 0;
7979
for (int l = 0; l < 8; l++)
8080
byte |= pres[j*32 + k * 8 + l] << l;
81+
// std::cout <<"j: " << j << " k: " << k << std::endl;
8182
// std::cout << (int)state[j][k] << " " << (int)byte << std::endl;
8283
assert(state[j][k] == byte);
8384
}

0 commit comments

Comments
 (0)