Skip to content

Commit 81223b0

Browse files
committed
fix: fix compression code for GPU which assumed a CPU data layout
- the CPU data layout is truncated to only store the relevant bodies (i.e. empty bodies are assumed to be 0), but the GPU CUDA code manages full GLWEs only. To fix that, we manage the data layout during conversions so that behavior is consistent when copying the list to/from CPU/GPU. The compression code has been fixed on the CPU side to have the proper length for the output expected by the CUDA code.
1 parent 726a771 commit 81223b0

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::core_crypto::entities::packed_integers::PackedIntegers;
22
use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex};
33
use crate::core_crypto::gpu::CudaStreams;
44
use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext;
5-
use crate::core_crypto::prelude::{CiphertextCount, LweCiphertextCount};
5+
use crate::core_crypto::prelude::{glwe_ciphertext_size, CiphertextCount, LweCiphertextCount};
66
use crate::integer::ciphertext::{CompressedCiphertextList, DataKind};
77
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
88
use crate::integer::gpu::ciphertext::{
@@ -326,11 +326,25 @@ impl CompressedCiphertextList {
326326
let message_modulus = self.packed_list.message_modulus;
327327
let carry_modulus = self.packed_list.carry_modulus;
328328

329-
let flat_cpu_data = modulus_switched_glwe_ciphertext_list
329+
let mut flat_cpu_data = modulus_switched_glwe_ciphertext_list
330330
.iter()
331331
.flat_map(|ct| ct.packed_integers.packed_coeffs.clone())
332332
.collect_vec();
333333

334+
let glwe_ciphertext_count = self.packed_list.modulus_switched_glwe_ciphertext_list.len();
335+
let glwe_size = self.packed_list.modulus_switched_glwe_ciphertext_list[0]
336+
.glwe_dimension()
337+
.to_glwe_size();
338+
let polynomial_size =
339+
self.packed_list.modulus_switched_glwe_ciphertext_list[0].polynomial_size();
340+
341+
// FIXME: have a more precise memory handling, this is too long and should be "just" the
342+
// original flat_cpu_data.len()
343+
let unpacked_glwe_ciphertext_flat_len =
344+
glwe_ciphertext_count * glwe_ciphertext_size(glwe_size, polynomial_size);
345+
346+
flat_cpu_data.resize(unpacked_glwe_ciphertext_flat_len, 0u64);
347+
334348
let flat_gpu_data = unsafe {
335349
let v = CudaVec::from_cpu_async(flat_cpu_data.as_slice(), streams, 0);
336350
streams.synchronize();

tfhe/src/integer/gpu/list_compression/server_keys.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
33
use crate::core_crypto::gpu::vec::CudaVec;
44
use crate::core_crypto::gpu::CudaStreams;
55
use crate::core_crypto::prelude::{
6-
glwe_ciphertext_size, glwe_mask_size, CiphertextModulus, CiphertextModulusLog,
7-
GlweCiphertextCount, LweCiphertextCount, PolynomialSize,
6+
glwe_ciphertext_size, CiphertextModulus, CiphertextModulusLog, GlweCiphertextCount,
7+
LweCiphertextCount, PolynomialSize,
88
};
99
use crate::integer::ciphertext::DataKind;
1010
use crate::integer::compression_keys::CompressionKey;
@@ -173,12 +173,12 @@ impl CudaCompressionKey {
173173
.sum();
174174

175175
let num_glwes = num_lwes.div_ceil(self.lwe_per_glwe.0);
176-
let glwe_mask_size = glwe_mask_size(
177-
compressed_glwe_size.to_glwe_dimension(),
178-
compressed_polynomial_size,
179-
);
176+
let glwe_ciphertext_size =
177+
glwe_ciphertext_size(compressed_glwe_size, compressed_polynomial_size);
180178
// The number of u64 (both mask and bodies)
181-
let uncompressed_len = num_glwes * glwe_mask_size + num_lwes;
179+
// FIXME: have a more precise memory handling, this is too long and should be
180+
// num_glwes * glwe_mask_size + num_lwes
181+
let uncompressed_len = num_glwes * glwe_ciphertext_size;
182182
let number_bits_to_pack = uncompressed_len * self.storage_log_modulus.0;
183183
let compressed_len = number_bits_to_pack.div_ceil(u64::BITS as usize);
184184
let mut packed_glwe_list = CudaVec::new(compressed_len, streams, 0);

0 commit comments

Comments
 (0)