Skip to content

Commit 726a771

Browse files
tmontaiguIceTDrinker
authored andcommitted
fix(integer): rotations/shifts < 2 blocks
This commit fixes a few bugs * The shift/rotate functions used when blocks encrypt a number of bits that is a power of 2 was causing a panic when working on one block. - Also, when the number of blocks was low (e.g 2 blocks with 2_2 params) a noise cleaning step was wrongly skipped * The function used when blocks encrypt non power of 2 number of bits also had a problem The test have been updated to test with different block sizes and check the noise level Overall these bugs only affected low block counts (e.g FheUint2, FheUint4) ciphertexts
1 parent f031272 commit 726a771

File tree

3 files changed

+313
-184
lines changed

3 files changed

+313
-184
lines changed

tfhe/src/integer/server_key/radix_parallel/block_shift.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,13 @@ impl ServerKey {
6161
T: IntegerRadixCiphertext,
6262
{
6363
if d_range.is_empty() {
64-
return ct.clone();
64+
let mut result = ct.clone();
65+
result
66+
.blocks_mut()
67+
.par_iter_mut()
68+
.filter(|b| b.noise_level > NoiseLevel::NOMINAL)
69+
.for_each(|block| self.key.message_extract_assign(block));
70+
return result;
6571
}
6672

6773
assert!(

tfhe/src/integer/server_key/radix_parallel/shift.rs

Lines changed: 89 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -370,21 +370,59 @@ impl ServerKey {
370370
where
371371
T: IntegerRadixCiphertext,
372372
{
373+
if amount.blocks.is_empty() || ct.blocks().is_empty() {
374+
return ct.clone();
375+
}
376+
373377
let message_bits_per_block = self.key.message_modulus.0.ilog2() as u64;
374378
let carry_bits_per_block = self.key.carry_modulus.0.ilog2() as u64;
375379
assert!(carry_bits_per_block >= message_bits_per_block);
380+
assert!(message_bits_per_block.is_power_of_two());
376381

377-
// Extracts bits and put them in the bit index 2 (=> bit number 3)
378-
// so that it is already aligned to the correct position of the cmux input,
379-
// and we reduce noise growth
380-
let mut shift_bit_extractor = BitExtractor::with_final_offset(
381-
&amount.blocks,
382-
self,
383-
message_bits_per_block as usize,
384-
message_bits_per_block as usize,
385-
);
382+
if ct.blocks().len() == 1 {
383+
let lut = self
384+
.key
385+
.generate_lookup_table_bivariate(|input, first_shift_block| {
386+
let shift_within_block = first_shift_block % message_bits_per_block;
386387

387-
assert!(message_bits_per_block.is_power_of_two());
388+
match operation {
389+
BarrelShifterOperation::LeftShift => {
390+
(input << shift_within_block) % self.message_modulus().0
391+
}
392+
BarrelShifterOperation::LeftRotate => {
393+
let shifted = (input << shift_within_block) % self.message_modulus().0;
394+
let wrapped = input >> (shift_within_block);
395+
shifted | wrapped
396+
}
397+
BarrelShifterOperation::RightRotate => {
398+
let shifted = input >> shift_within_block;
399+
let wrapped = (input << shift_within_block) % self.message_modulus().0;
400+
wrapped | shifted
401+
}
402+
BarrelShifterOperation::RightShift => {
403+
if T::IS_SIGNED {
404+
let sign_bit_pos = message_bits_per_block - 1;
405+
let sign_bit = (input >> sign_bit_pos) & 1;
406+
let padding_block = (self.message_modulus().0 - 1) * sign_bit;
407+
408+
// Pad with sign bits to 'simulate' an arithmetic shift
409+
let input = (padding_block << message_bits_per_block) | input;
410+
(input >> shift_within_block) % self.message_modulus().0
411+
} else {
412+
input >> shift_within_block
413+
}
414+
}
415+
}
416+
});
417+
418+
let block = self.key.unchecked_apply_lookup_table_bivariate(
419+
&ct.blocks()[0],
420+
&amount.blocks[0],
421+
&lut,
422+
);
423+
424+
return T::from_blocks(vec![block]);
425+
}
388426

389427
let message_for_block =
390428
self.key
@@ -408,6 +446,45 @@ impl ServerKey {
408446
b
409447
}
410448
});
449+
450+
// When doing right shift of a signed ciphertext, we do an arithmetic shift
451+
// Thus, we need some special luts to be used on the last block
452+
// (which has the sign bit)
453+
let message_for_block_right_shift_signed =
454+
if T::IS_SIGNED && operation == BarrelShifterOperation::RightShift {
455+
let lut = self
456+
.key
457+
.generate_lookup_table_bivariate(|input, first_shift_block| {
458+
let shift_within_block = first_shift_block % message_bits_per_block;
459+
let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;
460+
461+
let sign_bit_pos = message_bits_per_block - 1;
462+
let sign_bit = (input >> sign_bit_pos) & 1;
463+
let padding_block = (self.message_modulus().0 - 1) * sign_bit;
464+
465+
if shift_to_next_block == 1 {
466+
padding_block
467+
} else {
468+
// Pad with sign bits to 'simulate' an arithmetic shift
469+
let input = (padding_block << message_bits_per_block) | input;
470+
(input >> shift_within_block) % self.message_modulus().0
471+
}
472+
});
473+
Some(lut)
474+
} else {
475+
None
476+
};
477+
478+
// Extracts bits and put them in the bit index 2 (=> bit number 3)
479+
// so that it is already aligned to the correct position of the cmux input,
480+
// and we reduce noise growth
481+
let mut shift_bit_extractor = BitExtractor::with_final_offset(
482+
&amount.blocks,
483+
self,
484+
message_bits_per_block as usize,
485+
message_bits_per_block as usize,
486+
);
487+
411488
let message_for_next_block =
412489
self.key
413490
.generate_lookup_table_bivariate(|previous, first_shift_block| {
@@ -467,34 +544,6 @@ impl ServerKey {
467544
}
468545
});
469546

470-
// When doing right shift of a signed ciphertext, we do an arithmetic shift
471-
// Thus, we need some special luts to be used on the last block
472-
// (which has the sign big)
473-
let message_for_block_right_shift_signed =
474-
if T::IS_SIGNED && operation == BarrelShifterOperation::RightShift {
475-
let lut = self
476-
.key
477-
.generate_lookup_table_bivariate(|input, first_shift_block| {
478-
let shift_within_block = first_shift_block % message_bits_per_block;
479-
let shift_to_next_block = (first_shift_block / message_bits_per_block) % 2;
480-
481-
let sign_bit_pos = message_bits_per_block - 1;
482-
let sign_bit = (input >> sign_bit_pos) & 1;
483-
let padding_block = (self.message_modulus().0 - 1) * sign_bit;
484-
485-
if shift_to_next_block == 1 {
486-
padding_block
487-
} else {
488-
// Pad with sign bits to 'simulate' an arithmetic shift
489-
let input = (padding_block << message_bits_per_block) | input;
490-
(input >> shift_within_block) % self.message_modulus().0
491-
}
492-
});
493-
Some(lut)
494-
} else {
495-
None
496-
};
497-
498547
let message_for_next_block_right_shift_signed = if T::IS_SIGNED
499548
&& operation == BarrelShifterOperation::RightShift
500549
{
@@ -693,7 +742,8 @@ impl ServerKey {
693742
) where
694743
T: IntegerRadixCiphertext,
695744
{
696-
let num_blocks = shift.blocks.len();
745+
// What matters is the len of the ct to shift, not the `shift` len
746+
let num_blocks = ct.blocks().len();
697747
let message_bits_per_block = self.key.message_modulus.0.ilog2() as u64;
698748
let carry_bits_per_block = self.key.carry_modulus.0.ilog2() as u64;
699749
let total_nb_bits = message_bits_per_block * num_blocks as u64;

0 commit comments

Comments
 (0)