Skip to content

Commit 597c61b

Browse files
committed
chore(shortint): add tests for the KS32 AP
1 parent 8a26df9 commit 597c61b

File tree

6 files changed

+104
-24
lines changed

6 files changed

+104
-24
lines changed

scripts/test_filtering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def filter_shortint_tests(input_args):
192192
msg_carry_pairs.append((4, 4))
193193

194194
filter_expression = [
195-
f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks_pbs.*/)"
195+
f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks(32)?_pbs.*/)"
196196
for msg, carry in msg_carry_pairs
197197
]
198198
filter_expression.append("test(/^shortint::.*_ci_run_filter/)")

tfhe/src/shortint/atomic_pattern/mod.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,3 +551,46 @@ impl From<KS32AtomicPatternServerKey> for AtomicPatternServerKey {
551551
Self::KeySwitch32(value)
552552
}
553553
}
554+
555+
#[cfg(test)]
556+
mod test {
557+
use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
558+
use crate::shortint::{gen_keys, ServerKey};
559+
560+
use super::AtomicPatternServerKey;
561+
562+
// Test an implementation of the KS32 AP as a dynamic atomic pattern
563+
#[test]
564+
fn test_ks32_as_dyn_ap_ci_run_filter() {
565+
let (client_key, server_key) =
566+
gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128);
567+
568+
// Convert the static ks 32 server key into a dynamic one
569+
let AtomicPatternServerKey::KeySwitch32(ks32_key) = server_key.atomic_pattern else {
570+
panic!("We know from parameters that AP is KS32")
571+
};
572+
573+
let ap_key = AtomicPatternServerKey::Dynamic(Box::new(ks32_key));
574+
575+
// Re create the server key with the DAP
576+
let server_key = ServerKey::from_raw_parts(
577+
ap_key,
578+
server_key.message_modulus,
579+
server_key.carry_modulus,
580+
server_key.max_degree,
581+
server_key.max_noise_level,
582+
);
583+
584+
// Do some operation
585+
let msg1 = 1;
586+
let msg2 = 0;
587+
588+
let ct_1 = client_key.encrypt(msg1);
589+
let ct_2 = client_key.encrypt(msg2);
590+
591+
let ct_3 = server_key.add(&ct_1, &ct_2);
592+
593+
let output = client_key.decrypt(&ct_3);
594+
assert_eq!(output, 1);
595+
}
596+
}

tfhe/src/shortint/oprf.rs

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,11 @@ impl<AP: AtomicPattern> GenericServerKey<AP> {
207207

208208
#[cfg(test)]
209209
pub(crate) mod test {
210-
use crate::core_crypto::prelude::decrypt_lwe_ciphertext;
211-
use crate::shortint::oprf::create_random_from_seed_modulus_switched;
212-
use crate::shortint::{ClientKey, ServerKey};
210+
use crate::core_crypto::prelude::{decrypt_lwe_ciphertext, LweSecretKey};
211+
use crate::shortint::{ClientKey, ServerKey, ShortintParameterSet};
212+
213+
use super::*;
214+
213215
use rayon::prelude::*;
214216
use statrs::distribution::ContinuousCDF;
215217
use std::collections::HashMap;
@@ -222,22 +224,34 @@ pub(crate) mod test {
222224
#[test]
223225
fn oprf_compare_plain_ci_run_filter() {
224226
use crate::shortint::gen_keys;
227+
use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
225228
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
229+
226230
let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
227231

228232
for seed in 0..1000 {
229-
oprf_compare_plain_from_seed(Seed(seed), &ck, &sk);
233+
oprf_compare_plain_from_seed::<u64>(Seed(seed), &ck, &sk);
234+
}
235+
236+
let (ck, sk) = gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128);
237+
238+
for seed in 0..1000 {
239+
oprf_compare_plain_from_seed::<u32>(Seed(seed), &ck, &sk);
230240
}
231241
}
232242

233-
fn oprf_compare_plain_from_seed(seed: Seed, ck: &ClientKey, sk: &ServerKey) {
243+
fn oprf_compare_plain_from_seed<Scalar: UnsignedInteger + CastFrom<u64> + CastInto<u64>>(
244+
seed: Seed,
245+
ck: &ClientKey,
246+
sk: &ServerKey,
247+
) {
234248
let params = ck.parameters;
235249

236250
let random_bits_count = 2;
237251

238252
let input_p = 2 * params.polynomial_size().0 as u64;
239253

240-
let log_input_p = input_p.ilog2();
254+
let log_input_p = input_p.ilog2() as usize;
241255

242256
let p_prime = 1 << random_bits_count;
243257

@@ -255,15 +269,24 @@ pub(crate) mod test {
255269
params
256270
.polynomial_size()
257271
.to_blind_rotation_input_modulus_log(),
258-
sk.ciphertext_modulus,
272+
CiphertextModulus::new_native(),
259273
);
260274

261-
let sk = ck.small_lwe_secret_key();
275+
let sk = LweSecretKey::from_container(
276+
ck.small_lwe_secret_key()
277+
.as_ref()
278+
.iter()
279+
.copied()
280+
.map(|x| Scalar::cast_from(x))
281+
.collect::<Vec<_>>(),
282+
);
262283

263-
let plain_prf_input = decrypt_lwe_ciphertext(&sk, &ct)
264-
.0
265-
.wrapping_add(1 << (64 - log_input_p - 1))
266-
>> (64 - log_input_p);
284+
let plain_prf_input = CastInto::<u64>::cast_into(
285+
decrypt_lwe_ciphertext(&sk, &ct)
286+
.0
287+
.wrapping_add(Scalar::ONE << (Scalar::BITS - log_input_p - 1))
288+
>> (Scalar::BITS - log_input_p),
289+
);
267290

268291
let half_negacyclic_part = |x| 2 * (x / poly_delta) + 1;
269292

@@ -296,20 +319,28 @@ pub(crate) mod test {
296319
let p_value_limit: f64 = 0.000_01;
297320

298321
use crate::shortint::gen_keys;
322+
use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
299323
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
300-
let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
301324

302-
let test_uniformity = |distinct_values: u64, f: &(dyn Fn(usize) -> u64 + Sync)| {
303-
test_uniformity(sample_count, p_value_limit, distinct_values, f)
304-
};
325+
for params in [
326+
ShortintParameterSet::from(PARAM_MESSAGE_2_CARRY_2_KS_PBS),
327+
ShortintParameterSet::from(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128),
328+
] {
329+
let (ck, sk) = gen_keys(params);
305330

306-
let random_bits_count = 2;
331+
let test_uniformity = |distinct_values: u64, f: &(dyn Fn(usize) -> u64 + Sync)| {
332+
test_uniformity(sample_count, p_value_limit, distinct_values, f)
333+
};
334+
335+
let random_bits_count = 2;
307336

308-
test_uniformity(1 << random_bits_count, &|seed| {
309-
let img = sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count);
337+
test_uniformity(1 << random_bits_count, &|seed| {
338+
let img =
339+
sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count);
310340

311-
ck.decrypt_message_and_carry(&img)
312-
});
341+
ck.decrypt_message_and_carry(&img)
342+
});
343+
}
313344
}
314345

315346
pub fn test_uniformity<F>(sample_count: usize, p_value_limit: f64, distinct_values: u64, f: F)

tfhe/src/shortint/parameters/aliases.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use current_params::multi_bit::tuniform::p_fail_2_minus_64::ks_pbs_gpu::{
4444
V1_1_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64,
4545
};
4646
use current_params::noise_squashing::p_fail_2_minus_128::V1_1_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
47+
4748
// Aliases
4849

4950
// Compute Gaussian

tfhe/src/shortint/parameters/test_params.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::current_params::*;
2-
use super::AtomicPatternParameters;
2+
use super::{AtomicPatternParameters, KeySwitch32PBSParameters};
33

44
use super::{
55
ClassicPBSParameters, CompactPublicKeyEncryptionParameters, CompressionParameters,
@@ -209,3 +209,7 @@ pub const TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: CompressionPa
209209
pub const TEST_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128:
210210
CompressionParameters =
211211
V1_1_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
212+
213+
// KS32 PBS AP
214+
pub const TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128: KeySwitch32PBSParameters =
215+
V1_1_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;

tfhe/src/shortint/server_key/tests/parameterized_test.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ macro_rules! create_parameterized_test{
5555
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
5656
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
5757
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
58-
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64
58+
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
59+
TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128
5960
});
6061
};
6162
}

0 commit comments

Comments
 (0)