Skip to content

Commit 075e1b7

Browse files
committed
refactor(integer): cover more cases for sanitization during expansion
1 parent 1a3b2d7 commit 075e1b7

File tree

1 file changed

+57
-37
lines changed

1 file changed

+57
-37
lines changed

tfhe/src/integer/ciphertext/compact_list.rs

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,12 @@ fn unpack_and_sanitize_message_and_carries(
113113

114114
/// This function sanitizes boolean blocks to make sure they encrypt a 0 or a 1
115115
fn sanitize_boolean_blocks(
116-
packed_blocks: Vec<Ciphertext>,
116+
expanded_blocks: Vec<Ciphertext>,
117117
sks: &ServerKey,
118118
infos: &[DataKind],
119119
) -> Vec<Ciphertext> {
120120
let message_modulus = sks.message_modulus().0;
121+
let msg_extract = sks.key.generate_lookup_table(|x: u64| x % message_modulus);
121122
let msg_extract_bool = sks.key.generate_lookup_table(|x: u64| {
122123
let tmp = x % message_modulus;
123124
if tmp == 0 {
@@ -138,15 +139,15 @@ fn sanitize_boolean_blocks(
138139
let acc = if matches!(data_kind, DataKind::Boolean) {
139140
Some(&msg_extract_bool)
140141
} else {
141-
None
142+
Some(&msg_extract)
142143
};
143144

144145
functions[overall_block_idx] = acc;
145146
overall_block_idx += 1;
146147
}
147148
}
148149

149-
packed_blocks
150+
expanded_blocks
150151
.into_par_iter()
151152
.zip(functions.into_par_iter())
152153
.map(|(mut block, sanitize_acc)| {
@@ -479,7 +480,10 @@ impl IntegerUnpackingToShortintCastingModeHelper {
479480
}
480481
}
481482

482-
pub fn generate_function(&self, infos: &[DataKind]) -> CastingFunctionsOwned {
483+
pub fn generate_unpack_and_sanitize_functions(
484+
&self,
485+
infos: &[DataKind],
486+
) -> CastingFunctionsOwned {
483487
let block_count: usize = infos.iter().map(|x| x.num_blocks()).sum();
484488
let packed_block_count = block_count.div_ceil(2);
485489
let mut functions = vec![Some(Vec::with_capacity(2)); packed_block_count];
@@ -515,6 +519,30 @@ impl IntegerUnpackingToShortintCastingModeHelper {
515519

516520
functions
517521
}
522+
523+
pub fn generate_sanitize_without_unpacking_functions(
524+
&self,
525+
infos: &[DataKind],
526+
) -> CastingFunctionsOwned {
527+
let total_block_count: usize = infos.iter().map(|x| x.num_blocks()).sum();
528+
let mut functions = Vec::with_capacity(total_block_count);
529+
530+
for data_kind in infos {
531+
let block_count = data_kind.num_blocks();
532+
for _ in 0..block_count {
533+
let sanitize_function: &(dyn Fn(u64) -> u64 + Sync) =
534+
if matches!(data_kind, DataKind::Boolean) {
535+
self.msg_extract_bool.as_ref()
536+
} else {
537+
self.msg_extract.as_ref()
538+
};
539+
540+
functions.push(Some(vec![sanitize_function]));
541+
}
542+
}
543+
544+
functions
545+
}
518546
}
519547

520548
impl CompactCiphertextList {
@@ -681,23 +709,21 @@ impl CompactCiphertextList {
681709
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
682710
key_switching_key_view,
683711
) => {
684-
let function_helper;
685-
let functions;
712+
let dest_sks = &key_switching_key_view.key.dest_server_key;
713+
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
714+
dest_sks.message_modulus,
715+
dest_sks.carry_modulus,
716+
);
686717
let functions = if is_packed {
687-
let dest_sks = &key_switching_key_view.key.dest_server_key;
688-
function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
689-
dest_sks.message_modulus,
690-
dest_sks.carry_modulus,
691-
);
692-
functions = function_helper.generate_function(&self.info);
693-
Some(functions.as_slice())
718+
function_helper.generate_unpack_and_sanitize_functions(&self.info)
694719
} else {
695-
None
720+
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
696721
};
722+
697723
self.ct_list
698724
.expand(ShortintCompactCiphertextListCastingMode::CastIfNecessary {
699725
casting_key: key_switching_key_view.key,
700-
functions,
726+
functions: Some(functions.as_slice()),
701727
})?
702728
}
703729
IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => {
@@ -811,26 +837,23 @@ impl ProvenCompactCiphertextList {
811837
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
812838
key_switching_key_view,
813839
) => {
814-
let function_helper;
815-
let functions;
840+
let dest_sks = &key_switching_key_view.key.dest_server_key;
841+
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
842+
dest_sks.message_modulus,
843+
dest_sks.carry_modulus,
844+
);
816845
let functions = if is_packed {
817-
let dest_sks = &key_switching_key_view.key.dest_server_key;
818-
function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
819-
dest_sks.message_modulus,
820-
dest_sks.carry_modulus,
821-
);
822-
functions = function_helper.generate_function(&self.info);
823-
Some(functions.as_slice())
846+
function_helper.generate_unpack_and_sanitize_functions(&self.info)
824847
} else {
825-
None
848+
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
826849
};
827850
self.ct_list.verify_and_expand(
828851
crs,
829852
&public_key.key,
830853
metadata,
831854
ShortintCompactCiphertextListCastingMode::CastIfNecessary {
832855
casting_key: key_switching_key_view.key,
833-
functions,
856+
functions: Some(functions.as_slice()),
834857
},
835858
)?
836859
}
@@ -902,23 +925,20 @@ impl ProvenCompactCiphertextList {
902925
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
903926
key_switching_key_view,
904927
) => {
905-
let function_helper;
906-
let functions;
928+
let dest_sks = &key_switching_key_view.key.dest_server_key;
929+
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
930+
dest_sks.message_modulus,
931+
dest_sks.carry_modulus,
932+
);
907933
let functions = if is_packed {
908-
let dest_sks = &key_switching_key_view.key.dest_server_key;
909-
function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
910-
dest_sks.message_modulus,
911-
dest_sks.carry_modulus,
912-
);
913-
functions = function_helper.generate_function(&self.info);
914-
Some(functions.as_slice())
934+
function_helper.generate_unpack_and_sanitize_functions(&self.info)
915935
} else {
916-
None
936+
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
917937
};
918938
self.ct_list.expand_without_verification(
919939
ShortintCompactCiphertextListCastingMode::CastIfNecessary {
920940
casting_key: key_switching_key_view.key,
921-
functions,
941+
functions: Some(functions.as_slice()),
922942
},
923943
)?
924944
}

0 commit comments

Comments
 (0)