Skip to content

Commit 866937d

Browse files
committed
Added InvMixColumns
1 parent e910460 commit 866937d

File tree

3 files changed

+248
-101
lines changed

3 files changed

+248
-101
lines changed

include/aes.hpp

Lines changed: 164 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <AES.h>
3+
#include <ranges>
34
// Transciphering by AES
45
// Based on Hippogryph
56
namespace TFHEpp {
@@ -123,6 +124,112 @@ void InvShiftRows(std::array<TLWE<P>, 128> &res)
123124
ShiftRow<P,3,aesNb-3>(res);
124125
}
125126

127+
template <class P>
128+
void MixColumn(std::array<TLWE<P>, 32>& y_out, const std::array<TLWE<P>, 32>& x){
129+
// Temporary variables for intermediate XOR results, based on the 92-gate circuit
130+
// Naming corresponds to t0...t59 and y0...y31 as per the circuit.
131+
// We will use 't' for all intermediate values and 'y_out' for the final 32 output bits.
132+
std::array<TLWE<P>, 60> t; // For t0...t59
133+
134+
// Implement the 92 XOR gates from Listing 1 in the provided PDF [cite: 28, 29]
135+
// Note: The paper uses various symbols for XOR (e.g., ^, ~, -, ´). We interpret all as XOR.
136+
// The indices for x will be 0-31.
137+
138+
// Implement the 92 XOR gates based on the user-provided circuit
139+
TLWEAdd<P>(t[0], x[0], x[8]);
140+
TLWEAdd<P>(t[1], x[16], x[24]);
141+
TLWEAdd<P>(t[2], x[1], x[9]);
142+
TLWEAdd<P>(t[3], x[17], x[25]);
143+
TLWEAdd<P>(t[4], x[2], x[10]);
144+
TLWEAdd<P>(t[5], x[18], x[26]);
145+
TLWEAdd<P>(t[6], x[3], x[11]);
146+
TLWEAdd<P>(t[7], x[19], x[27]);
147+
TLWEAdd<P>(t[8], x[4], x[12]);
148+
TLWEAdd<P>(t[9], x[20], x[28]);
149+
TLWEAdd<P>(t[10], x[5], x[13]);
150+
TLWEAdd<P>(t[11], x[21], x[29]);
151+
TLWEAdd<P>(t[12], x[6], x[14]);
152+
TLWEAdd<P>(t[13], x[22], x[30]);
153+
TLWEAdd<P>(t[14], x[23], x[31]);
154+
TLWEAdd<P>(t[15], x[7], x[15]);
155+
TLWEAdd<P>(t[16], x[8], t[1]);
156+
TLWEAdd<P>(y_out[0], t[15], t[16]);
157+
TLWEAdd<P>(t[17], x[7], x[23]);
158+
TLWEAdd<P>(t[18], x[24], t[0]);
159+
TLWEAdd<P>(y_out[16], t[14], t[18]);
160+
TLWEAdd<P>(t[19], t[1], y_out[16]);
161+
TLWEAdd<P>(y_out[24], t[17], t[19]);
162+
TLWEAdd<P>(t[20], x[27], t[14]);
163+
TLWEAdd<P>(t[21], t[0], y_out[0]);
164+
TLWEAdd<P>(y_out[8], t[17], t[21]);
165+
TLWEAdd<P>(t[22], t[5], t[20]);
166+
TLWEAdd<P>(y_out[19], t[6], t[22]);
167+
TLWEAdd<P>(t[23], x[11], t[15]);
168+
TLWEAdd<P>(t[24], t[7], t[23]);
169+
TLWEAdd<P>(y_out[3], t[4], t[24]);
170+
TLWEAdd<P>(t[25], x[2], x[18]);
171+
TLWEAdd<P>(t[26], t[17], t[25]);
172+
TLWEAdd<P>(t[27], t[9], t[23]);
173+
TLWEAdd<P>(t[28], t[8], t[20]);
174+
TLWEAdd<P>(t[29], x[10], t[2]);
175+
TLWEAdd<P>(y_out[2], t[5], t[29]);
176+
TLWEAdd<P>(t[30], x[26], t[3]);
177+
TLWEAdd<P>(y_out[18], t[4], t[30]);
178+
TLWEAdd<P>(t[31], x[9], x[25]);
179+
TLWEAdd<P>(t[32], t[25], t[31]);
180+
TLWEAdd<P>(y_out[10], t[30], t[32]);
181+
TLWEAdd<P>(y_out[26], t[29], t[32]);
182+
TLWEAdd<P>(t[33], x[1], t[18]);
183+
TLWEAdd<P>(t[34], x[30], t[11]);
184+
TLWEAdd<P>(y_out[22], t[12], t[34]);
185+
TLWEAdd<P>(t[35], x[14], t[13]);
186+
TLWEAdd<P>(y_out[6], t[10], t[35]);
187+
TLWEAdd<P>(t[36], x[5], x[21]);
188+
TLWEAdd<P>(t[37], x[30], t[17]);
189+
TLWEAdd<P>(t[38], x[17], t[16]);
190+
TLWEAdd<P>(t[39], x[13], t[8]);
191+
TLWEAdd<P>(y_out[5], t[11], t[39]);
192+
TLWEAdd<P>(t[40], x[12], t[36]);
193+
TLWEAdd<P>(t[41], x[29], t[9]);
194+
TLWEAdd<P>(y_out[21], t[10], t[41]);
195+
TLWEAdd<P>(t[42], x[28], t[40]);
196+
TLWEAdd<P>(y_out[13], t[41], t[42]);
197+
TLWEAdd<P>(y_out[29], t[39], t[42]);
198+
TLWEAdd<P>(t[43], x[15], t[12]);
199+
TLWEAdd<P>(y_out[7], t[14], t[43]);
200+
TLWEAdd<P>(t[44], x[14], t[37]);
201+
TLWEAdd<P>(y_out[31], t[43], t[44]);
202+
TLWEAdd<P>(t[45], x[31], t[13]);
203+
TLWEAdd<P>(y_out[15], t[44], t[45]);
204+
TLWEAdd<P>(y_out[23], t[15], t[45]);
205+
TLWEAdd<P>(t[46], t[12], t[36]);
206+
TLWEAdd<P>(y_out[14], y_out[6], t[46]);
207+
TLWEAdd<P>(t[47], t[31], t[33]);
208+
TLWEAdd<P>(y_out[17], t[19], t[47]);
209+
TLWEAdd<P>(t[48], t[6], y_out[3]);
210+
TLWEAdd<P>(y_out[11], t[26], t[48]);
211+
TLWEAdd<P>(t[49], t[2], t[38]);
212+
TLWEAdd<P>(y_out[25], y_out[24], t[49]);
213+
TLWEAdd<P>(t[50], t[7], y_out[19]);
214+
TLWEAdd<P>(y_out[27], t[26], t[50]);
215+
TLWEAdd<P>(t[51], x[22], t[46]);
216+
TLWEAdd<P>(y_out[30], t[11], t[51]);
217+
TLWEAdd<P>(t[52], x[19], t[28]);
218+
TLWEAdd<P>(y_out[20], x[28], t[52]);
219+
TLWEAdd<P>(t[53], x[3], t[27]);
220+
TLWEAdd<P>(y_out[4], x[12], t[53]);
221+
TLWEAdd<P>(t[54], t[3], t[33]);
222+
TLWEAdd<P>(y_out[9], y_out[8], t[54]);
223+
TLWEAdd<P>(t[55], t[21], t[31]);
224+
TLWEAdd<P>(y_out[1], t[38], t[55]);
225+
TLWEAdd<P>(t[56], x[4], t[17]);
226+
TLWEAdd<P>(t[57], x[19], t[56]);
227+
TLWEAdd<P>(y_out[12], t[27], t[57]);
228+
TLWEAdd<P>(t[58], x[3], t[28]);
229+
TLWEAdd<P>(t[59], t[17], t[58]);
230+
TLWEAdd<P>(y_out[28], x[20], t[59]);
231+
}
232+
126233
// https://eprint.iacr.org/2019/833
127234
template <class P>
128235
void MixColumns(std::array<TLWE<P>, 128> &state) {
@@ -137,109 +244,64 @@ void MixColumns(std::array<TLWE<P>, 128> &state) {
137244
x[i*8+j] = state[i*32 + col*8 + j];
138245
for (int i = 0; i < 32; ++i) x[i][P::k * P::n] += 1ULL << (std::numeric_limits<typename P::T>::digits - 2);
139246

140-
// Temporary variables for intermediate XOR results, based on the 92-gate circuit
141-
// Naming corresponds to t0...t59 and y0...y31 as per the circuit.
142-
// We will use 't' for all intermediate values and 'y_out' for the final 32 output bits.
143-
std::array<TLWE<P>, 60> t; // For t0...t59
247+
// Apply the MixColumn transformation
144248
std::array<TLWE<P>, 32> y_out; // For final output bits y0...y31
249+
MixColumn<P>(y_out, x);
250+
251+
// Place the resulting 32-bit column (y_out) back into the state array
252+
for (int i = 0; i < 4; ++i)
253+
for (int j = 0; j < 8; ++j)
254+
state[i*32 + col*8 + j] = y_out[i*8+j];
255+
}
256+
for(int i = 0; i < 128; i++)
257+
state[i][P::k*P::n] -= (1ULL << (std::numeric_limits<typename P::T>::digits - 2));
258+
}
259+
260+
template <class P>
261+
void xxtimes(std::array<TLWE<P>,8> &statebyte)
262+
{
263+
std::array<TLWE<P>, 8> tmp;
264+
tmp[0] = statebyte[6];
265+
TLWEAdd<P>(tmp[1], statebyte[6], statebyte[7]);
266+
TLWEAdd<P>(tmp[2], statebyte[0], statebyte[7]);
267+
TLWEAdd<P>(tmp[3], statebyte[1], statebyte[6]);
268+
TLWEAdd<P>(tmp[4], statebyte[2], tmp[1]);
269+
TLWEAdd<P>(tmp[5], statebyte[3], statebyte[7]);
270+
tmp[6] = statebyte[4];
271+
tmp[7] = statebyte[5];
272+
for(int i = 0; i < 8; i++)
273+
statebyte[i] = tmp[i];
274+
}
275+
276+
// https://doi.org/10.1007/s13389-017-0176-3
277+
template <class P>
278+
void InvMixColumns(std::array<TLWE<P>, 128> &state) {
279+
// The Inverse MixColumns operation is applied to each 32-bit column of the state.
280+
// The AES state is 128 bits, so there are 4 such columns.
281+
282+
for (int col = 0; col < 4; ++col) {
283+
// Extract the current 32-bit column into a working array (x0 to x31)
284+
std::array<TLWE<P>, 32> x; // Input bits for the current column
285+
for (int i = 0; i < 4; ++i)
286+
for (int j = 0; j < 8; ++j)
287+
x[i*8+j] = state[i*32 + col*8 + j];
288+
for (int i = 0; i < 32; ++i) x[i][P::k * P::n] += 1ULL << (std::numeric_limits<typename P::T>::digits - 2); // Shift to support XOR
145289

146-
// Implement the 92 XOR gates from Listing 1 in the provided PDF [cite: 28, 29]
147-
// Note: The paper uses various symbols for XOR (e.g., ^, ~, -, ´). We interpret all as XOR.
148-
// The indices for x will be 0-31.
149-
150-
// Implement the 92 XOR gates based on the user-provided circuit
151-
TLWEAdd<P>(t[0], x[0], x[8]);
152-
TLWEAdd<P>(t[1], x[16], x[24]);
153-
TLWEAdd<P>(t[2], x[1], x[9]);
154-
TLWEAdd<P>(t[3], x[17], x[25]);
155-
TLWEAdd<P>(t[4], x[2], x[10]);
156-
TLWEAdd<P>(t[5], x[18], x[26]);
157-
TLWEAdd<P>(t[6], x[3], x[11]);
158-
TLWEAdd<P>(t[7], x[19], x[27]);
159-
TLWEAdd<P>(t[8], x[4], x[12]);
160-
TLWEAdd<P>(t[9], x[20], x[28]);
161-
TLWEAdd<P>(t[10], x[5], x[13]);
162-
TLWEAdd<P>(t[11], x[21], x[29]);
163-
TLWEAdd<P>(t[12], x[6], x[14]);
164-
TLWEAdd<P>(t[13], x[22], x[30]);
165-
TLWEAdd<P>(t[14], x[23], x[31]);
166-
TLWEAdd<P>(t[15], x[7], x[15]);
167-
TLWEAdd<P>(t[16], x[8], t[1]);
168-
TLWEAdd<P>(y_out[0], t[15], t[16]);
169-
TLWEAdd<P>(t[17], x[7], x[23]);
170-
TLWEAdd<P>(t[18], x[24], t[0]);
171-
TLWEAdd<P>(y_out[16], t[14], t[18]);
172-
TLWEAdd<P>(t[19], t[1], y_out[16]);
173-
TLWEAdd<P>(y_out[24], t[17], t[19]);
174-
TLWEAdd<P>(t[20], x[27], t[14]);
175-
TLWEAdd<P>(t[21], t[0], y_out[0]);
176-
TLWEAdd<P>(y_out[8], t[17], t[21]);
177-
TLWEAdd<P>(t[22], t[5], t[20]);
178-
TLWEAdd<P>(y_out[19], t[6], t[22]);
179-
TLWEAdd<P>(t[23], x[11], t[15]);
180-
TLWEAdd<P>(t[24], t[7], t[23]);
181-
TLWEAdd<P>(y_out[3], t[4], t[24]);
182-
TLWEAdd<P>(t[25], x[2], x[18]);
183-
TLWEAdd<P>(t[26], t[17], t[25]);
184-
TLWEAdd<P>(t[27], t[9], t[23]);
185-
TLWEAdd<P>(t[28], t[8], t[20]);
186-
TLWEAdd<P>(t[29], x[10], t[2]);
187-
TLWEAdd<P>(y_out[2], t[5], t[29]);
188-
TLWEAdd<P>(t[30], x[26], t[3]);
189-
TLWEAdd<P>(y_out[18], t[4], t[30]);
190-
TLWEAdd<P>(t[31], x[9], x[25]);
191-
TLWEAdd<P>(t[32], t[25], t[31]);
192-
TLWEAdd<P>(y_out[10], t[30], t[32]);
193-
TLWEAdd<P>(y_out[26], t[29], t[32]);
194-
TLWEAdd<P>(t[33], x[1], t[18]);
195-
TLWEAdd<P>(t[34], x[30], t[11]);
196-
TLWEAdd<P>(y_out[22], t[12], t[34]);
197-
TLWEAdd<P>(t[35], x[14], t[13]);
198-
TLWEAdd<P>(y_out[6], t[10], t[35]);
199-
TLWEAdd<P>(t[36], x[5], x[21]);
200-
TLWEAdd<P>(t[37], x[30], t[17]);
201-
TLWEAdd<P>(t[38], x[17], t[16]);
202-
TLWEAdd<P>(t[39], x[13], t[8]);
203-
TLWEAdd<P>(y_out[5], t[11], t[39]);
204-
TLWEAdd<P>(t[40], x[12], t[36]);
205-
TLWEAdd<P>(t[41], x[29], t[9]);
206-
TLWEAdd<P>(y_out[21], t[10], t[41]);
207-
TLWEAdd<P>(t[42], x[28], t[40]);
208-
TLWEAdd<P>(y_out[13], t[41], t[42]);
209-
TLWEAdd<P>(y_out[29], t[39], t[42]);
210-
TLWEAdd<P>(t[43], x[15], t[12]);
211-
TLWEAdd<P>(y_out[7], t[14], t[43]);
212-
TLWEAdd<P>(t[44], x[14], t[37]);
213-
TLWEAdd<P>(y_out[31], t[43], t[44]);
214-
TLWEAdd<P>(t[45], x[31], t[13]);
215-
TLWEAdd<P>(y_out[15], t[44], t[45]);
216-
TLWEAdd<P>(y_out[23], t[15], t[45]);
217-
TLWEAdd<P>(t[46], t[12], t[36]);
218-
TLWEAdd<P>(y_out[14], y_out[6], t[46]);
219-
TLWEAdd<P>(t[47], t[31], t[33]);
220-
TLWEAdd<P>(y_out[17], t[19], t[47]);
221-
TLWEAdd<P>(t[48], t[6], y_out[3]);
222-
TLWEAdd<P>(y_out[11], t[26], t[48]);
223-
TLWEAdd<P>(t[49], t[2], t[38]);
224-
TLWEAdd<P>(y_out[25], y_out[24], t[49]);
225-
TLWEAdd<P>(t[50], t[7], y_out[19]);
226-
TLWEAdd<P>(y_out[27], t[26], t[50]);
227-
TLWEAdd<P>(t[51], x[22], t[46]);
228-
TLWEAdd<P>(y_out[30], t[11], t[51]);
229-
TLWEAdd<P>(t[52], x[19], t[28]);
230-
TLWEAdd<P>(y_out[20], x[28], t[52]);
231-
TLWEAdd<P>(t[53], x[3], t[27]);
232-
TLWEAdd<P>(y_out[4], x[12], t[53]);
233-
TLWEAdd<P>(t[54], t[3], t[33]);
234-
TLWEAdd<P>(y_out[9], y_out[8], t[54]);
235-
TLWEAdd<P>(t[55], t[21], t[31]);
236-
TLWEAdd<P>(y_out[1], t[38], t[55]);
237-
TLWEAdd<P>(t[56], x[4], t[17]);
238-
TLWEAdd<P>(t[57], x[19], t[56]);
239-
TLWEAdd<P>(y_out[12], t[27], t[57]);
240-
TLWEAdd<P>(t[58], x[3], t[28]);
241-
TLWEAdd<P>(t[59], t[17], t[58]);
242-
TLWEAdd<P>(y_out[28], x[20], t[59]);
290+
// Apply the Inverse MixColumn transformation (Decomposed to fix and MixColumn)
291+
for (int offset = 0; offset < 2; ++offset) {
292+
std::array<TLWE<P>, 8> statebyte;
293+
for(int i = 0; i < 8; i++)
294+
TLWEAdd<P>(statebyte[i], x[offset*8+i], x[(2+offset)*8+i]);
295+
xxtimes<P>(statebyte);
296+
for(int i = 0; i < 8; i++)
297+
TLWEAdd<P>(x[offset*8+i], x[offset*8+i], statebyte[i]);
298+
for(int i = 0; i < 8; i++)
299+
TLWEAdd<P>(x[(2+offset)*8+i], x[(2+offset)*8+i], statebyte[i]);
300+
}
301+
302+
// Apply the MixColumn transformation
303+
std::array<TLWE<P>, 32> y_out; // For final output bits y0...y31
304+
MixColumn<P>(y_out, x);
243305

244306
// Place the resulting 32-bit column (y_out) back into the state array
245307
for (int i = 0; i < 4; ++i)
@@ -249,4 +311,5 @@ void MixColumns(std::array<TLWE<P>, 128> &state) {
249311
for(int i = 0; i < 128; i++)
250312
state[i][P::k*P::n] -= (1ULL << (std::numeric_limits<typename P::T>::digits - 2));
251313
}
314+
252315
} // namespace TFHEpp

test/invmixcolumns.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#include <cassert>
2+
#include <chrono>
3+
#include <iostream>
4+
#include <memory>
5+
#include <random>
6+
#include <tfhe++.hpp>
7+
8+
void InvMixColumns(unsigned char state[4][4]) {
9+
unsigned char temp_state[4][4];
10+
11+
for (size_t i = 0; i < 4; ++i) {
12+
memset(temp_state[i], 0, 4);
13+
}
14+
15+
for (size_t i = 0; i < 4; ++i) {
16+
for (size_t k = 0; k < 4; ++k) {
17+
for (size_t j = 0; j < 4; ++j) {
18+
temp_state[i][j] ^= GF_MUL_TABLE[INV_CMDS[i][k]][state[k][j]];
19+
}
20+
}
21+
}
22+
23+
for (size_t i = 0; i < 4; ++i) {
24+
memcpy(state[i], temp_state[i], 4);
25+
}
26+
}
27+
28+
int main()
29+
{
30+
using P = TFHEpp::lvl2param;
31+
std::random_device seed_gen;
32+
std::default_random_engine engine(seed_gen());
33+
std::uniform_int_distribution<uint32_t> binary(0, 1);
34+
35+
std::unique_ptr<TFHEpp::SecretKey> sk(new TFHEpp::SecretKey());
36+
constexpr uint num_test = 1000;
37+
std::vector<std::array<TFHEpp::TLWE<P>,128>> cstate(num_test);
38+
std::vector<std::array<uint8_t, 128>> plaintext(num_test);
39+
40+
for (int i = 0; i < num_test; i++) {
41+
for (int j = 0; j < 128; j++){
42+
plaintext[i][j] = binary(engine);
43+
cstate[i][j] = TFHEpp::tlweSymEncrypt<P>(plaintext[i][j]?1ULL << (std::numeric_limits<typename P::T>::digits - 2):-(1ULL << (std::numeric_limits<typename P::T>::digits - 2)), *sk);
44+
}
45+
}
46+
47+
std::chrono::system_clock::time_point start, end;
48+
start = std::chrono::system_clock::now();
49+
for (int test = 0; test < num_test; test++) {
50+
// std::cout << "test: " << test << std::endl;
51+
TFHEpp::InvMixColumns<P>(cstate[test]);
52+
}
53+
54+
end = std::chrono::system_clock::now();
55+
double elapsed =
56+
std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
57+
.count();
58+
std::cout << elapsed / num_test << "ms" << std::endl;
59+
60+
for (int i = 0; i < num_test; i++) {
61+
std::array<uint8_t, 128> pres;
62+
for (int j = 0; j < 128; j++)
63+
pres[j] = TFHEpp::tlweSymDecrypt<P>(cstate[i][j], *sk);
64+
unsigned char state[4][4];
65+
for (int j = 0; j < 4; j++)
66+
for (int k = 0; k < 4; k++){
67+
uint8_t byte = 0;
68+
for (int l = 0; l < 8; l++)
69+
byte |= plaintext[i][j*32 + k * 8 + l] << l;
70+
state[j][k] = byte;
71+
}
72+
InvMixColumns(state);
73+
for (int j = 0; j < 4; j++)
74+
for (int k = 0; k < 4; k++){
75+
uint8_t byte = 0;
76+
for (int l = 0; l < 8; l++)
77+
byte |= pres[j*32 + k * 8 + l] << l;
78+
// std::cout << (int)state[j][k] << " " << (int)byte << std::endl;
79+
assert(state[j][k] == byte);
80+
}
81+
}
82+
std::cout << "PASS" << std::endl;
83+
}

test/mixcolumns.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,5 @@ int main()
8282
assert(state[j][k] == byte);
8383
}
8484
}
85+
std::cout << "PASS" << std::endl;
8586
}

0 commit comments

Comments
 (0)