@@ -33,6 +33,15 @@ impl From<ServerKey> for crate::shortint::ServerKey {
33
33
}
34
34
}
35
35
36
+ /// Compute the [`MaxDegree`] for an integer server key (compressed or uncompressed). This formula
37
+ /// provisions a free carry bit. This allows carry propagation between shortint blocks in a
38
+ /// [`RadixCiphertext`](`crate::integer::RadixCiphertext`), as that process requires adding a bit of
39
+ /// carry from one shortint block to the next, which would overflow and lead to wrong results if we
40
+ /// did not provision that carry bit.
41
+ fn integer_server_key_max_degree ( parameters : crate :: shortint:: ShortintParameterSet ) -> MaxDegree {
42
+ MaxDegree ( ( parameters. message_modulus ( ) . 0 - 1 ) * parameters. carry_modulus ( ) . 0 - 1 )
43
+ }
44
+
36
45
impl ServerKey {
37
46
/// Generates a server key.
38
47
///
@@ -54,13 +63,11 @@ impl ServerKey {
54
63
{
55
64
// It should remain just enough space to add a carry
56
65
let client_key = cks. as_ref ( ) ;
57
- let max = ( client_key. key . parameters . message_modulus ( ) . 0 - 1 )
58
- * client_key. key . parameters . carry_modulus ( ) . 0
59
- - 1 ;
66
+ let max_degree = integer_server_key_max_degree ( client_key. key . parameters ) ;
60
67
61
68
let sks = crate :: shortint:: server_key:: ServerKey :: new_with_max_degree (
62
69
& client_key. key ,
63
- MaxDegree ( max ) ,
70
+ max_degree ,
64
71
) ;
65
72
66
73
ServerKey { key : sks }
@@ -87,10 +94,9 @@ impl ServerKey {
87
94
mut key : crate :: shortint:: server_key:: ServerKey ,
88
95
) -> ServerKey {
89
96
// It should remain just enough space add a carry
90
- let max =
91
- ( cks. key . parameters . message_modulus ( ) . 0 - 1 ) * cks. key . parameters . carry_modulus ( ) . 0 - 1 ;
97
+ let max_degree = integer_server_key_max_degree ( cks. key . parameters ) ;
92
98
93
- key. max_degree = MaxDegree ( max ) ;
99
+ key. max_degree = max_degree ;
94
100
ServerKey { key }
95
101
}
96
102
@@ -111,7 +117,10 @@ pub struct CompressedServerKey {
111
117
112
118
impl CompressedServerKey {
113
119
pub fn new ( client_key : & ClientKey ) -> CompressedServerKey {
114
- let key = crate :: shortint:: CompressedServerKey :: new ( & client_key. key ) ;
120
+ let max_degree = integer_server_key_max_degree ( client_key. key . parameters ) ;
121
+
122
+ let key =
123
+ crate :: shortint:: CompressedServerKey :: new_with_max_degree ( & client_key. key , max_degree) ;
115
124
Self { key }
116
125
}
117
126
}
@@ -122,3 +131,46 @@ impl From<CompressedServerKey> for ServerKey {
122
131
Self { key }
123
132
}
124
133
}
134
+
135
+ #[ cfg( test) ]
136
+ mod test {
137
+ use super :: * ;
138
+ use crate :: integer:: RadixClientKey ;
139
+ use crate :: shortint:: parameters:: PARAM_MESSAGE_2_CARRY_2 ;
140
+
141
+ /// https://github.com/zama-ai/tfhe-rs/issues/460
142
+ /// Problem with CompressedServerKey degree being set to shortint MaxDegree not accounting for
143
+ /// the necessary carry bits for e.g. Radix carry propagation.
144
+ #[ test]
145
+ fn test_compressed_server_key_max_degree ( ) {
146
+ let cks = ClientKey :: new ( crate :: shortint:: parameters:: PARAM_MESSAGE_2_CARRY_2_KS_PBS ) ;
147
+ // msg_mod = 4, carry_mod = 4, (msg_mod - 1) * carry_mod = 12; minus 1 => 11
148
+ let expected_max_degree = MaxDegree ( 11 ) ;
149
+
150
+ let sks = ServerKey :: new ( & cks) ;
151
+ assert_eq ! ( sks. key. max_degree, expected_max_degree) ;
152
+
153
+ let csks = CompressedServerKey :: new ( & cks) ;
154
+ assert_eq ! ( csks. key. max_degree, expected_max_degree) ;
155
+
156
+ let decompressed_sks: ServerKey = csks. into ( ) ;
157
+ assert_eq ! ( decompressed_sks. key. max_degree, expected_max_degree) ;
158
+
159
+ // Repro case from the user
160
+ {
161
+ let client_key = RadixClientKey :: new ( PARAM_MESSAGE_2_CARRY_2 , 14 ) ;
162
+ let compressed_eval_key = CompressedServerKey :: new ( client_key. as_ref ( ) ) ;
163
+ let evaluation_key = ServerKey :: from ( compressed_eval_key) ;
164
+ let modulus = ( client_key. parameters ( ) . message_modulus ( ) . 0 as u128 )
165
+ . pow ( client_key. num_blocks ( ) as u32 ) ;
166
+
167
+ let mut ct = client_key. encrypt ( modulus - 1 ) ;
168
+ let mut res_ct = ct. clone ( ) ;
169
+ for _ in 0 ..5 {
170
+ res_ct = evaluation_key. smart_add_parallelized ( & mut res_ct, & mut ct) ;
171
+ }
172
+ let res = client_key. decrypt :: < u128 > ( & res_ct) ;
173
+ assert_eq ! ( modulus - 6 , res) ;
174
+ }
175
+ }
176
+ }
0 commit comments