1
1
#pragma once
2
2
3
3
#include < array>
4
+ #include < span>
5
+ #include < bit>
4
6
5
7
#include " params.hpp"
6
8
#include " trgsw.hpp"
@@ -183,6 +185,72 @@ void SubsetIdentityKeySwitch(TLWE<typename P::targetP> &res,
183
185
}
184
186
}
185
187
188
+ template <class P >
189
+ void PrivKeySwitch (TRLWE<typename P::targetP> &res,
190
+ const TLWE<typename P::domainP> &tlwe,
191
+ const PrivateKeySwitchingKey<P> &privksk)
192
+ {
193
+ constexpr typename P::domainP::T roundoffset =
194
+ 1ULL << (std::numeric_limits<typename P::domainP::T>::digits -
195
+ (1 + P::basebit * P::t));
196
+
197
+ // Koga's Optimization
198
+ constexpr typename P::domainP::T offset = iksoffsetgen<P>();
199
+ constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1 ;
200
+ constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1 );
201
+ res = {};
202
+ for (int i = 0 ; i <= P::domainP::k * P::domainP::n; i++) {
203
+ const typename P::domainP::T aibar = tlwe[i] + offset + roundoffset;
204
+
205
+ for (int j = 0 ; j < P::t; j++) {
206
+ const int32_t aij =
207
+ ((aibar >>
208
+ (std::numeric_limits<typename P::domainP::T>::digits -
209
+ (j + 1 ) * P::basebit)) &
210
+ mask) -
211
+ halfbase;
212
+
213
+ if (aij > 0 )
214
+ for (int k = 0 ; k < P::targetP::k + 1 ; k++)
215
+ for (int p = 0 ; p < P::targetP::n; p++)
216
+ res[k][p] -= privksk[i][j][aij - 1 ][k][p];
217
+ else if (aij < 0 )
218
+ for (int k = 0 ; k < P::targetP::k + 1 ; k++)
219
+ for (int p = 0 ; p < P::targetP::n; p++)
220
+ res[k][p] += privksk[i][j][abs (aij) - 1 ][k][p];
221
+ }
222
+ }
223
+ }
224
+
225
+ template <class P >
226
+ void SubsetPrivKeySwitch (TRLWE<typename P::targetP> &res,
227
+ const TLWE<typename P::targetP> &tlwe,
228
+ const SubsetPrivateKeySwitchingKey<P> &privksk)
229
+ {
230
+ constexpr uint32_t mask = (1 << P::basebit) - 1 ;
231
+ constexpr uint64_t prec_offset =
232
+ 1ULL << (std::numeric_limits<typename P::targetP::T>::digits -
233
+ (1 + P::basebit * P::t));
234
+
235
+ res = {};
236
+ for (int i = 0 ; i <= P::targetP::k * P::targetP::n; i++) {
237
+ const typename P::targetP::T aibar = tlwe[i] + prec_offset;
238
+
239
+ for (int j = 0 ; j < P::t; j++) {
240
+ const typename P::domainP::T aij =
241
+ (aibar >> (std::numeric_limits<typename P::targetP::T>::digits -
242
+ (j + 1 ) * P::basebit)) &
243
+ mask;
244
+
245
+ if (aij != 0 ) {
246
+ for (int p = 0 ; p < P::targetP::n; p++)
247
+ for (int k = 0 ; k < P::targetP::k + 1 ; k++)
248
+ res[k][p] -= privksk[i][j][aij - 1 ][k][p];
249
+ }
250
+ }
251
+ }
252
+ }
253
+
186
254
template <class P >
187
255
void TLWE2TRLWEIKS (TRLWE<typename P::targetP> &res,
188
256
const TLWE<typename P::domainP> &tlwe,
@@ -239,113 +307,85 @@ void EvalAuto(TRLWE<P> &res, const TRLWE<P> &trlwe, const int d,
239
307
}
240
308
241
309
// https://eprint.iacr.org/2024/1318
310
+ // The order is reversed, but this is easily proved correct by packing a trivial all-zero TRLWE.
242
311
// TODO: They say we should divide by N first, not by 2 at each step. Why?
243
312
template <class P >
244
313
void AnnihilateKeySwitching (TRLWE<P> &res, const TRLWE<P> &trlwe,
245
314
const AnnihilateKey<P> &ahk)
246
315
{
247
316
res = trlwe;
248
- for (int j = 0 ; j < (P::k + 1 ) * P::n; j++) res[0 ][j] /= P::n;
317
+ // for (int j = 0; j < (P::k + 1) * P::n; j++) res[0][j] /= P::n;
249
318
for (int i = 0 ; i < P::nbit; i++) {
319
+ for (int j = 0 ; j < (P::k + 1 ) * P::n; j++) res[0 ][j] /= 2 ;
250
320
TRLWE<P> evaledauto;
251
- EvalAuto<P>(evaledauto, res, (1 << (P::nbit - i )) + 1 , ahk[i]);
321
+ EvalAuto<P>(evaledauto, res, (1 << (i+ 1 )) + 1 , ahk[i]);
252
322
for (int j = 0 ; j < (P::k + 1 ) * P::n; j++)
253
323
res[0 ][j] += evaledauto[0 ][j];
254
324
}
255
325
}
256
326
257
- template <class P , uint num_func>
258
- void AnnihilatePrivateKeySwitching (
259
- std::array<TRLWE<P>, num_func> &res, const TRLWE<P> &trlwe,
260
- const AnnihilateKey<P> &ahk,
261
- const std::array<TRGSWFFT<P>, num_func> &privks)
262
- {
263
- static_assert (num_func > 0 , " num_func must be bigger than 0" );
264
- res[num_func - 1 ] = trlwe;
265
- for (int i = 0 ; i < P::nbit - 1 ; i++) {
266
- TRLWE<P> evaledauto;
267
- EvalAuto<P>(evaledauto, res[num_func - 1 ], (1 << (P::nbit - i)) + 1 ,
268
- ahk[i]);
269
- for (int j = 0 ; j < (P::k + 1 ) * P::n; j++)
270
- res[num_func - 1 ][0 ][j] += evaledauto[0 ][j];
271
- }
272
- for (int i = 0 ; i < num_func; i++) {
273
- TRLWE<P> evaledauto;
274
- EvalAuto<P>(evaledauto, res[num_func - 1 ], (1 << (P::nbit - i)) + 1 ,
275
- privks[i]);
276
- for (int j = 0 ; j < (P::k + 1 ) * P::n; j++)
277
- res[i][0 ][j] += res[num_func - 1 ][0 ][j] + evaledauto[0 ][j];
278
- }
279
- }
327
+ // template <class P, uint num_func>
328
+ // void AnnihilatePrivateKeySwitching(
329
+ // std::array<TRLWE<P>, num_func> &res, const TRLWE<P> &trlwe,
330
+ // const AnnihilateKey<P> &ahk,
331
+ // const std::array<TRGSWFFT<P>, num_func> &privks)
332
+ // {
333
+ // static_assert(num_func > 0, "num_func must be bigger than 0");
334
+ // res[num_func - 1] = trlwe;
335
+ // for (int i = 0; i < P::nbit - 1; i++) {
336
+ // TRLWE<P> evaledauto;
337
+ // EvalAuto<P>(evaledauto, res[num_func - 1], (1 << (P::nbit - i)) + 1,
338
+ // ahk[i]);
339
+ // for (int j = 0; j < (P::k + 1) * P::n; j++)
340
+ // res[num_func - 1][0][j] += evaledauto[0][j];
341
+ // }
342
+ // for (int i = 0; i < num_func; i++) {
343
+ // TRLWE<P> evaledauto;
344
+ // EvalAuto<P>(evaledauto, res[num_func - 1], (1 << (P::nbit - i)) + 1,
345
+ // privks[i]);
346
+ // for (int j = 0; j < (P::k + 1) * P::n; j++)
347
+ // res[i][0][j] += res[num_func - 1][0][j] + evaledauto[0][j];
348
+ // }
349
+ // }
280
350
281
- template <class P >
282
- void PrivKeySwitch (TRLWE<typename P::targetP> &res,
283
- const TLWE<typename P::domainP> &tlwe,
284
- const PrivateKeySwitchingKey<P> &privksk)
285
- {
286
- constexpr typename P::domainP::T roundoffset =
287
- 1ULL << (std::numeric_limits<typename P::domainP::T>::digits -
288
- (1 + P::basebit * P::t));
289
-
290
- // Koga's Optimization
291
- constexpr typename P::domainP::T offset = iksoffsetgen<P>();
292
- constexpr typename P::domainP::T mask = (1ULL << P::basebit) - 1 ;
293
- constexpr typename P::domainP::T halfbase = 1ULL << (P::basebit - 1 );
294
- res = {};
295
- for (int i = 0 ; i <= P::domainP::k * P::domainP::n; i++) {
296
- const typename P::domainP::T aibar = tlwe[i] + offset + roundoffset;
297
-
298
- for (int j = 0 ; j < P::t; j++) {
299
- const int32_t aij =
300
- ((aibar >>
301
- (std::numeric_limits<typename P::domainP::T>::digits -
302
- (j + 1 ) * P::basebit)) &
303
- mask) -
304
- halfbase;
305
-
306
- if (aij > 0 )
307
- for (int k = 0 ; k < P::targetP::k + 1 ; k++)
308
- for (int p = 0 ; p < P::targetP::n; p++)
309
- res[k][p] -= privksk[i][j][aij - 1 ][k][p];
310
- else if (aij < 0 )
311
- for (int k = 0 ; k < P::targetP::k + 1 ; k++)
312
- for (int p = 0 ; p < P::targetP::n; p++)
313
- res[k][p] += privksk[i][j][abs (aij) - 1 ][k][p];
314
- }
315
- }
316
- }
317
-
318
- template <class P >
319
- void SubsetPrivKeySwitch (TRLWE<typename P::targetP> &res,
320
- const TLWE<typename P::targetP> &tlwe,
321
- const SubsetPrivateKeySwitchingKey<P> &privksk)
322
- {
323
- constexpr uint32_t mask = (1 << P::basebit) - 1 ;
324
- constexpr uint64_t prec_offset =
325
- 1ULL << (std::numeric_limits<typename P::targetP::T>::digits -
326
- (1 + P::basebit * P::t));
327
-
328
- res = {};
329
- for (int i = 0 ; i <= P::targetP::k * P::targetP::n; i++) {
330
- const typename P::targetP::T aibar = tlwe[i] + prec_offset;
331
-
332
- for (int j = 0 ; j < P::t; j++) {
333
- const typename P::domainP::T aij =
334
- (aibar >> (std::numeric_limits<typename P::targetP::T>::digits -
335
- (j + 1 ) * P::basebit)) &
336
- mask;
351
+ // template <class P, uint num_tlwe>
352
+ // void AnnihilatePacking(TRLWE<P> &res, const std::array<TLWE<P>, num_tlwe> &tlwes,
353
+ // const AnnihilateKey<P> &ahk)
354
+ // {
355
+ // static_assert(std::has_single_bit(num_tlwe), "Currently, num_tlwe must be power of 2");
356
+ // std::array<TRLWE<P>, num_tlwe> trlwes;
357
+ // constexpr uint l = std::count_zero(num_tlwe);
358
+ // for (int i = 0; i < num_tlwe; i++) {
359
+ // InvSampleExtractIndex<P>(trlwes[i], tlwes[i], 0);
360
+ // for (int j = 0; j <= P::k * P::n; j++)//rest are known to be 0
361
+ // trlwes[i][0][j] /= P::n;
362
+ // }
363
+ // // Using res as a temporary variable
364
+ // for (int i = 0; i < l; i++){
365
+ // constexpr uint stride = 1 << (l - i - 1);
366
+ // for(int j = 0; j < stride; j++){
367
+ // PolynomialMulByXai<P>(res, trlwes[stride+j], P::n >> i);
368
+ // for(int k = 0; i < (P::k+1) * P::n; k++)
369
+ // trlwes[stride+j][k] = trlwes[j][k] - res[k];
370
+ // for(int k = 0; i < (P::k+1) * P::n; k++)
371
+ // trlwes[j][k] += res[k];
372
+ // EvalAuto<P>(res, trlwes[stride+j], (1 << (P::nbit - i)) + 1, ahk[i]);
373
+ // for(int k = 0; i < (P::k+1) * P::n; k++)
374
+ // trlwes[j][k] += res[k];
375
+ // }
376
+ // }
377
+ // res = trlwes[0];
378
+ // // using trlews[0] and trlwes[1] as temporary variables
379
+ // for (int i = l; i < P::nbit; i++) {
380
+ // PolynomialMulByXai<P>(res, trlwes[(1<<i)+j], P::n >> i);
381
+ // EvalAuto<P>(evaledauto, res, (1 << (P::nbit - i)) + 1, ahk[i]);
382
+ // for (int j = 0; j < (P::k + 1) * P::n; j++)
383
+ // res[0][j] += evaledauto[0][j];
384
+ // }
385
+ // }
337
386
338
- if (aij != 0 ) {
339
- for (int p = 0 ; p < P::targetP::n; p++)
340
- for (int k = 0 ; k < P::targetP::k + 1 ; k++)
341
- res[k][p] -= privksk[i][j][aij - 1 ][k][p];
342
- }
343
- }
344
- }
345
- }
346
-
347
- template <class P >
348
- void PackLWEs (TRLWE<P> &res, const std::vector<TLWE<P>> &tlwe,
387
+ template <class P , class Container >
388
+ void PackLWEs (TRLWE<P> &res, const Container &tlwe,
349
389
const AnnihilateKey<P> &ahk, const uint l, const uint offset,
350
390
const uint interval)
351
391
{
@@ -363,19 +403,17 @@ void PackLWEs(TRLWE<P> &res, const std::vector<TLWE<P>> &tlwe,
363
403
tempeven[i][j] /= 2 ;
364
404
tempoddmul[i][j] /= 2 ;
365
405
tempodd[i][j] = tempeven[i][j] - tempoddmul[i][j];
366
- // tempodd[i][j] = (tempeven[i][j] - tempoddmul[i][j])/2;
367
406
}
368
407
}
369
- EvalAuto<P>(res, tempodd, (1 << l) + 1 , ahk[P::nbit - l ]);
408
+ EvalAuto<P>(res, tempodd, (1 << l) + 1 , ahk[l- 1 ]);
370
409
for (int i = 0 ; i < P::k + 1 ; i++)
371
410
for (int j = 0 ; j < P::n; j++)
372
411
res[i][j] += tempeven[i][j] + tempoddmul[i][j];
373
- // res[i][j] += (tempeven[i][j] + tempoddmul[i][j])/2;
374
412
}
375
413
}
376
414
377
415
template <class P >
378
- void TLWE2TRLWEChengsPacking (TRLWE<P> &res, std::vector<TLWE<P>> &tlwe,
416
+ void TLWE2TRLWEChensPacking (TRLWE<P> &res, std::vector<TLWE<P>> &tlwe,
379
417
const AnnihilateKey<P> &ahk)
380
418
{
381
419
uint l = std::bit_width (tlwe.size ()) - 1 ;
@@ -384,15 +422,40 @@ void TLWE2TRLWEChengsPacking(TRLWE<P> &res, std::vector<TLWE<P>> &tlwe,
384
422
tlwe.resize (1 << l);
385
423
}
386
424
PackLWEs<P>(res, tlwe, ahk, l, 0 , 1 );
387
- for (int i = 0 ; i < P::nbit - l ; i++) {
425
+ for (int i = l ; i < P::nbit; i++) {
388
426
TRLWE<P> evaledauto;
389
427
for (int j = 0 ; j < (P::k + 1 ) * P::n; j++) res[0 ][j] /= 2 ;
390
- EvalAuto<P>(evaledauto, res, (1 << (P::nbit - i )) + 1 , ahk[i]);
428
+ EvalAuto<P>(evaledauto, res, (1 << (i+ 1 )) + 1 , ahk[i]);
391
429
for (int j = 0 ; j < (P::k + 1 ) * P::n; j++)
392
430
res[0 ][j] += evaledauto[0 ][j];
393
431
}
394
432
}
395
433
434
+ template <class P , uint num_tlwe>
435
+ void TLWE2TablePacking (TRLWE<P> &res, std::array<TLWE<P>,num_tlwe> &tlwe,
436
+ const AnnihilateKey<P> &ahk)
437
+ {
438
+ static_assert (std::has_single_bit (num_tlwe), " Currently, num_tlwe must be power of 2" );
439
+ constexpr uint l = std::countr_zero (num_tlwe);
440
+ PackLWEs<P>(res, tlwe, ahk, l, 0 , 1 );
441
+ for (int i = l; i < P::nbit; i++) {
442
+ TRLWE<P> tempmul;
443
+ for (int j = 0 ; j < P::k + 1 ; j++)
444
+ PolynomialMulByXai<P>(tempmul[j], res[j], P::n >> (i+1 ));
445
+ TRLWE<P> tempsub;
446
+ for (int j = 0 ; j < (P::k + 1 ) * P::n; j++){
447
+ res[0 ][j] /= 2 ;
448
+ tempmul[0 ][j] /= 2 ;
449
+ tempsub[0 ][j] = res[0 ][j] - tempmul[0 ][j];
450
+ res[0 ][j] += tempmul[0 ][j];
451
+ }
452
+ // reuse tempmul
453
+ EvalAuto<P>(tempmul, tempsub, (1 << (i+1 )) + 1 , ahk[i]);
454
+ for (int j = 0 ; j < (P::k + 1 ) * P::n; j++)
455
+ res[0 ][j] += tempmul[0 ][j];
456
+ }
457
+ }
458
+
396
459
template <class P >
397
460
void PackLWEsLSB (TRLWE<P> &res, const std::vector<TLWE<P>> &tlwe,
398
461
const AnnihilateKey<P> &ahk, const uint l, const uint offset,
@@ -415,14 +478,12 @@ void PackLWEsLSB(TRLWE<P> &res, const std::vector<TLWE<P>> &tlwe,
415
478
tempeven[i][j] /= 2 ;
416
479
tempoddmul[i][j] /= 2 ;
417
480
tempodd[i][j] = tempeven[i][j] - tempoddmul[i][j];
418
- // tempodd[i][j] = (tempeven[i][j] - tempoddmul[i][j])/2;
419
481
}
420
482
}
421
- EvalAuto<P>(res, tempodd, (1 << l) + 1 , ahk[P::nbit - l ]);
483
+ EvalAuto<P>(res, tempodd, (1 << l) + 1 , ahk[l- 1 ]);
422
484
for (int i = 0 ; i < P::k + 1 ; i++)
423
485
for (int j = 0 ; j < P::n; j++)
424
486
res[i][j] += tempeven[i][j] + tempoddmul[i][j];
425
- // res[i][j] += (tempeven[i][j] + tempoddmul[i][j])/2;
426
487
}
427
488
}
428
489
0 commit comments