Skip to content

Commit 2f295ea

Browse files
committed
fix(integer): fix unsigned_overflowing_sub on trivials
unsigned_overflowing_sub does an independant subtraction on each blocks with a correcting term being added to avoid trashing the padding bit (lhs - rhs + correction). The correction depended on rhs's degree. e.g. if rhs's degree was in range 1..(msg_mod-1) -> correction = msg_mod However if rhs's degree was zero (so rhs is a trivial 0), the correction was also 0, however the borrow propagation rely on that correction to always be added.
1 parent fa54a02 commit 2f295ea

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

tfhe/src/integer/server_key/radix/sub.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,17 @@ impl ServerKey {
416416
let mut borrow = self.key.create_trivial(0);
417417
let mut new_blocks = Vec::with_capacity(lhs.blocks.len());
418418
for (lhs_b, rhs_b) in lhs.blocks.iter().zip(rhs.blocks.iter()) {
419-
let mut result_block = self.key.unchecked_sub(lhs_b, rhs_b);
419+
let (mut result_block, correction) =
420+
self.key.unchecked_sub_with_correcting_term(lhs_b, rhs_b);
421+
if correction == 0 {
422+
// When rhs_block is a trivial zero, the correcting term added is 0
423+
// However we rely on that correcting term to be added regardless
424+
assert_eq!(rhs_b.degree.0, 0);
425+
self.key.unchecked_scalar_add_assign(
426+
&mut result_block,
427+
self.key.message_modulus.0 as u8,
428+
);
429+
}
420430
// Here unchecked_sub_assign does not give correct result, we don't want
421431
// the correcting term to be used
422432
// -> This is ok as the value returned by unchecked_sub is in range 1..(message_mod * 2)

tfhe/src/integer/server_key/radix_parallel/sub.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,21 @@ impl ServerKey {
361361
.blocks
362362
.iter()
363363
.zip(rhs.blocks.iter())
364-
.map(|(lhs_block, rhs_block)| self.key.unchecked_sub(lhs_block, rhs_block))
364+
.map(|(lhs_block, rhs_block)| {
365+
let (mut result_block, correction) = self
366+
.key
367+
.unchecked_sub_with_correcting_term(lhs_block, rhs_block);
368+
if correction == 0 {
369+
// When rhs_block is a trivial zero, the correcting term added is 0
370+
// However we rely on that correcting term to be added regardless
371+
assert_eq!(rhs_block.degree.0, 0);
372+
self.key.unchecked_scalar_add_assign(
373+
&mut result_block,
374+
self.key.message_modulus.0 as u8,
375+
);
376+
}
377+
result_block
378+
})
365379
.collect::<Vec<_>>();
366380
let mut ct = RadixCiphertext::from(ct);
367381

tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1822,6 +1822,40 @@ where
18221822
);
18231823
}
18241824
}
1825+
1826+
// Test with trivial inputs, as it was bugged at some point
1827+
for _ in 0..4 {
1828+
// Reduce maximum value of random number such that at least the last block is a trivial 0
1829+
// (This is how the reproducing case was found)
1830+
let clear_0 = rng.gen::<u64>() % (modulus / sks.key.message_modulus.0 as u64);
1831+
let clear_1 = rng.gen::<u64>() % (modulus / sks.key.message_modulus.0 as u64);
1832+
1833+
let a: RadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
1834+
let b: RadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
1835+
1836+
assert_eq!(a.blocks[NB_CTXT - 1].degree.0, 0);
1837+
assert_eq!(b.blocks[NB_CTXT - 1].degree.0, 0);
1838+
1839+
let (encrypted_result, encrypted_overflow) =
1840+
sks.unchecked_unsigned_overflowing_sub_parallelized(&a, &b);
1841+
1842+
let (expected_result, expected_overflowed) =
1843+
overflowing_sub_under_modulus(clear_0, clear_1, modulus);
1844+
1845+
let decrypted_result: u64 = cks.decrypt(&encrypted_result);
1846+
let decrypted_overflowed = cks.decrypt_one_block(&encrypted_overflow) == 1;
1847+
assert_eq!(
1848+
decrypted_result, expected_result,
1849+
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
1850+
expected {expected_result}, got {decrypted_result}"
1851+
);
1852+
assert_eq!(
1853+
decrypted_overflowed,
1854+
expected_overflowed,
1855+
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
1856+
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
1857+
);
1858+
}
18251859
}
18261860

18271861
pub(crate) fn default_sub_test<P, T>(param: P, mut executor: T)

0 commit comments

Comments
 (0)