Skip to content

Commit 98db328

Browse files
committed
fix(integer): set proper MaxDegree for CompressedServerKey
- add shortint API to generate a CompressedServerKey with MaxDegree - add non regression test based on the user issue - factorize MaxDegree computation for integer server keys
1 parent f5fab4d commit 98db328

File tree

2 files changed

+69
-8
lines changed

2 files changed

+69
-8
lines changed

tfhe/src/integer/server_key/mod.rs

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ impl From<ServerKey> for crate::shortint::ServerKey {
3333
}
3434
}
3535

36+
/// Compute the [`MaxDegree`] for an integer server key (compressed or uncompressed). This formula
37+
/// provisions a free carry bit. This allows carry propagation between shortint blocks in a
38+
/// [`RadixCiphertext`](`crate::integer::RadixCiphertext`), as that process requires adding a bit of
39+
/// carry from one shortint block to the next, which would overflow and lead to wrong results if we
40+
/// did not provision that carry bit.
41+
fn integer_server_key_max_degree(parameters: crate::shortint::ShortintParameterSet) -> MaxDegree {
42+
MaxDegree((parameters.message_modulus().0 - 1) * parameters.carry_modulus().0 - 1)
43+
}
44+
3645
impl ServerKey {
3746
/// Generates a server key.
3847
///
@@ -54,13 +63,11 @@ impl ServerKey {
5463
{
5564
// It should remain just enough space to add a carry
5665
let client_key = cks.as_ref();
57-
let max = (client_key.key.parameters.message_modulus().0 - 1)
58-
* client_key.key.parameters.carry_modulus().0
59-
- 1;
66+
let max_degree = integer_server_key_max_degree(client_key.key.parameters);
6067

6168
let sks = crate::shortint::server_key::ServerKey::new_with_max_degree(
6269
&client_key.key,
63-
MaxDegree(max),
70+
max_degree,
6471
);
6572

6673
ServerKey { key: sks }
@@ -87,10 +94,9 @@ impl ServerKey {
8794
mut key: crate::shortint::server_key::ServerKey,
8895
) -> ServerKey {
8996
// It should remain just enough space add a carry
90-
let max =
91-
(cks.key.parameters.message_modulus().0 - 1) * cks.key.parameters.carry_modulus().0 - 1;
97+
let max_degree = integer_server_key_max_degree(cks.key.parameters);
9298

93-
key.max_degree = MaxDegree(max);
99+
key.max_degree = max_degree;
94100
ServerKey { key }
95101
}
96102

@@ -111,7 +117,10 @@ pub struct CompressedServerKey {
111117

112118
impl CompressedServerKey {
113119
pub fn new(client_key: &ClientKey) -> CompressedServerKey {
114-
let key = crate::shortint::CompressedServerKey::new(&client_key.key);
120+
let max_degree = integer_server_key_max_degree(client_key.key.parameters);
121+
122+
let key =
123+
crate::shortint::CompressedServerKey::new_with_max_degree(&client_key.key, max_degree);
115124
Self { key }
116125
}
117126
}
@@ -122,3 +131,46 @@ impl From<CompressedServerKey> for ServerKey {
122131
Self { key }
123132
}
124133
}
134+
135+
#[cfg(test)]
136+
mod test {
137+
use super::*;
138+
use crate::integer::RadixClientKey;
139+
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
140+
141+
/// https://github.com/zama-ai/tfhe-rs/issues/460
142+
/// Problem with CompressedServerKey degree being set to shortint MaxDegree not accounting for
143+
/// the necessary carry bits for e.g. Radix carry propagation.
144+
#[test]
145+
fn test_compressed_server_key_max_degree() {
146+
let cks = ClientKey::new(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS);
147+
// msg_mod = 4, carry_mod = 4, (msg_mod - 1) * carry_mod = 12; minus 1 => 11
148+
let expected_max_degree = MaxDegree(11);
149+
150+
let sks = ServerKey::new(&cks);
151+
assert_eq!(sks.key.max_degree, expected_max_degree);
152+
153+
let csks = CompressedServerKey::new(&cks);
154+
assert_eq!(csks.key.max_degree, expected_max_degree);
155+
156+
let decompressed_sks: ServerKey = csks.into();
157+
assert_eq!(decompressed_sks.key.max_degree, expected_max_degree);
158+
159+
// Repro case from the user
160+
{
161+
let client_key = RadixClientKey::new(PARAM_MESSAGE_2_CARRY_2, 14);
162+
let compressed_eval_key = CompressedServerKey::new(client_key.as_ref());
163+
let evaluation_key = ServerKey::from(compressed_eval_key);
164+
let modulus = (client_key.parameters().message_modulus().0 as u128)
165+
.pow(client_key.num_blocks() as u32);
166+
167+
let mut ct = client_key.encrypt(modulus - 1);
168+
let mut res_ct = ct.clone();
169+
for _ in 0..5 {
170+
res_ct = evaluation_key.smart_add_parallelized(&mut res_ct, &mut ct);
171+
}
172+
let res = client_key.decrypt::<u128>(&res_ct);
173+
assert_eq!(modulus - 6, res);
174+
}
175+
}
176+
}

tfhe/src/shortint/server_key/compressed.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,13 @@ impl CompressedServerKey {
5454
engine.new_compressed_server_key(client_key).unwrap()
5555
})
5656
}
57+
58+
/// Generate a compressed server key with a chosen maximum degree
59+
pub fn new_with_max_degree(cks: &ClientKey, max_degree: MaxDegree) -> CompressedServerKey {
60+
ShortintEngine::with_thread_local_mut(|engine| {
61+
engine
62+
.new_compressed_server_key_with_max_degree(cks, max_degree)
63+
.unwrap()
64+
})
65+
}
5766
}

0 commit comments

Comments
 (0)