@@ -370,21 +370,59 @@ impl ServerKey {
370
370
where
371
371
T : IntegerRadixCiphertext ,
372
372
{
373
+ if amount. blocks . is_empty ( ) || ct. blocks ( ) . is_empty ( ) {
374
+ return ct. clone ( ) ;
375
+ }
376
+
373
377
let message_bits_per_block = self . key . message_modulus . 0 . ilog2 ( ) as u64 ;
374
378
let carry_bits_per_block = self . key . carry_modulus . 0 . ilog2 ( ) as u64 ;
375
379
assert ! ( carry_bits_per_block >= message_bits_per_block) ;
380
+ assert ! ( message_bits_per_block. is_power_of_two( ) ) ;
376
381
377
- // Extracts bits and put them in the bit index 2 (=> bit number 3)
378
- // so that it is already aligned to the correct position of the cmux input,
379
- // and we reduce noise growth
380
- let mut shift_bit_extractor = BitExtractor :: with_final_offset (
381
- & amount. blocks ,
382
- self ,
383
- message_bits_per_block as usize ,
384
- message_bits_per_block as usize ,
385
- ) ;
382
+ if ct. blocks ( ) . len ( ) == 1 {
383
+ let lut = self
384
+ . key
385
+ . generate_lookup_table_bivariate ( |input, first_shift_block| {
386
+ let shift_within_block = first_shift_block % message_bits_per_block;
386
387
387
- assert ! ( message_bits_per_block. is_power_of_two( ) ) ;
388
+ match operation {
389
+ BarrelShifterOperation :: LeftShift => {
390
+ ( input << shift_within_block) % self . message_modulus ( ) . 0
391
+ }
392
+ BarrelShifterOperation :: LeftRotate => {
393
+ let shifted = ( input << shift_within_block) % self . message_modulus ( ) . 0 ;
394
+ let wrapped = input >> ( shift_within_block) ;
395
+ shifted | wrapped
396
+ }
397
+ BarrelShifterOperation :: RightRotate => {
398
+ let shifted = input >> shift_within_block;
399
+ let wrapped = ( input << shift_within_block) % self . message_modulus ( ) . 0 ;
400
+ wrapped | shifted
401
+ }
402
+ BarrelShifterOperation :: RightShift => {
403
+ if T :: IS_SIGNED {
404
+ let sign_bit_pos = message_bits_per_block - 1 ;
405
+ let sign_bit = ( input >> sign_bit_pos) & 1 ;
406
+ let padding_block = ( self . message_modulus ( ) . 0 - 1 ) * sign_bit;
407
+
408
+ // Pad with sign bits to 'simulate' an arithmetic shift
409
+ let input = ( padding_block << message_bits_per_block) | input;
410
+ ( input >> shift_within_block) % self . message_modulus ( ) . 0
411
+ } else {
412
+ input >> shift_within_block
413
+ }
414
+ }
415
+ }
416
+ } ) ;
417
+
418
+ let block = self . key . unchecked_apply_lookup_table_bivariate (
419
+ & ct. blocks ( ) [ 0 ] ,
420
+ & amount. blocks [ 0 ] ,
421
+ & lut,
422
+ ) ;
423
+
424
+ return T :: from_blocks ( vec ! [ block] ) ;
425
+ }
388
426
389
427
let message_for_block =
390
428
self . key
@@ -408,6 +446,45 @@ impl ServerKey {
408
446
b
409
447
}
410
448
} ) ;
449
+
450
+ // When doing right shift of a signed ciphertext, we do an arithmetic shift
451
+ // Thus, we need some special luts to be used on the last block
452
+ // (which has the sign bit)
453
+ let message_for_block_right_shift_signed =
454
+ if T :: IS_SIGNED && operation == BarrelShifterOperation :: RightShift {
455
+ let lut = self
456
+ . key
457
+ . generate_lookup_table_bivariate ( |input, first_shift_block| {
458
+ let shift_within_block = first_shift_block % message_bits_per_block;
459
+ let shift_to_next_block = ( first_shift_block / message_bits_per_block) % 2 ;
460
+
461
+ let sign_bit_pos = message_bits_per_block - 1 ;
462
+ let sign_bit = ( input >> sign_bit_pos) & 1 ;
463
+ let padding_block = ( self . message_modulus ( ) . 0 - 1 ) * sign_bit;
464
+
465
+ if shift_to_next_block == 1 {
466
+ padding_block
467
+ } else {
468
+ // Pad with sign bits to 'simulate' an arithmetic shift
469
+ let input = ( padding_block << message_bits_per_block) | input;
470
+ ( input >> shift_within_block) % self . message_modulus ( ) . 0
471
+ }
472
+ } ) ;
473
+ Some ( lut)
474
+ } else {
475
+ None
476
+ } ;
477
+
478
+ // Extracts bits and put them in the bit index 2 (=> bit number 3)
479
+ // so that it is already aligned to the correct position of the cmux input,
480
+ // and we reduce noise growth
481
+ let mut shift_bit_extractor = BitExtractor :: with_final_offset (
482
+ & amount. blocks ,
483
+ self ,
484
+ message_bits_per_block as usize ,
485
+ message_bits_per_block as usize ,
486
+ ) ;
487
+
411
488
let message_for_next_block =
412
489
self . key
413
490
. generate_lookup_table_bivariate ( |previous, first_shift_block| {
@@ -467,34 +544,6 @@ impl ServerKey {
467
544
}
468
545
} ) ;
469
546
470
- // When doing right shift of a signed ciphertext, we do an arithmetic shift
471
- // Thus, we need some special luts to be used on the last block
472
- // (which has the sign big)
473
- let message_for_block_right_shift_signed =
474
- if T :: IS_SIGNED && operation == BarrelShifterOperation :: RightShift {
475
- let lut = self
476
- . key
477
- . generate_lookup_table_bivariate ( |input, first_shift_block| {
478
- let shift_within_block = first_shift_block % message_bits_per_block;
479
- let shift_to_next_block = ( first_shift_block / message_bits_per_block) % 2 ;
480
-
481
- let sign_bit_pos = message_bits_per_block - 1 ;
482
- let sign_bit = ( input >> sign_bit_pos) & 1 ;
483
- let padding_block = ( self . message_modulus ( ) . 0 - 1 ) * sign_bit;
484
-
485
- if shift_to_next_block == 1 {
486
- padding_block
487
- } else {
488
- // Pad with sign bits to 'simulate' an arithmetic shift
489
- let input = ( padding_block << message_bits_per_block) | input;
490
- ( input >> shift_within_block) % self . message_modulus ( ) . 0
491
- }
492
- } ) ;
493
- Some ( lut)
494
- } else {
495
- None
496
- } ;
497
-
498
547
let message_for_next_block_right_shift_signed = if T :: IS_SIGNED
499
548
&& operation == BarrelShifterOperation :: RightShift
500
549
{
@@ -693,7 +742,8 @@ impl ServerKey {
693
742
) where
694
743
T : IntegerRadixCiphertext ,
695
744
{
696
- let num_blocks = shift. blocks . len ( ) ;
745
+ // What matters is the len of the ct to shift, not the `shift` len
746
+ let num_blocks = ct. blocks ( ) . len ( ) ;
697
747
let message_bits_per_block = self . key . message_modulus . 0 . ilog2 ( ) as u64 ;
698
748
let carry_bits_per_block = self . key . carry_modulus . 0 . ilog2 ( ) as u64 ;
699
749
let total_nb_bits = message_bits_per_block * num_blocks as u64 ;
0 commit comments