Skip to content

Commit 84c93e9

Browse files
committed
chore(gpu): add vectorized function for bitand and plug it in the Array API
1 parent ffe9607 commit 84c93e9

File tree

6 files changed

+397
-21
lines changed

6 files changed

+397
-21
lines changed

backends/tfhe-cuda-backend/cuda/src/integer/bitwise_ops.cuh

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,43 @@ __host__ void host_integer_radix_bitop_kb(
2525
lwe_array_out->num_radix_blocks == lwe_array_2->num_radix_blocks,
2626
"Cuda error: input and output num radix blocks must be equal");
2727

28-
PANIC_IF_FALSE(
29-
lwe_array_out->num_radix_ciphertexts == lwe_array_1->num_radix_ciphertexts &&
30-
lwe_array_out->num_radix_ciphertexts == lwe_array_2->num_radix_ciphertexts,
31-
"Cuda error: input and output num radix ciphertexts must be equal");
28+
PANIC_IF_FALSE(
29+
lwe_array_out->num_radix_ciphertexts ==
30+
lwe_array_1->num_radix_ciphertexts &&
31+
lwe_array_out->num_radix_ciphertexts ==
32+
lwe_array_2->num_radix_ciphertexts,
33+
"Cuda error: input and output num radix ciphertexts must be equal");
3234

33-
PANIC_IF_FALSE(lwe_array_out->lwe_dimension == lwe_array_1->lwe_dimension &&
35+
PANIC_IF_FALSE(lwe_array_out->lwe_dimension == lwe_array_1->lwe_dimension &&
3436
lwe_array_out->lwe_dimension == lwe_array_2->lwe_dimension,
3537
"Cuda error: input and output lwe dimension must be equal");
3638

3739
auto lut = mem_ptr->lut;
38-
uint64_t degrees[lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts];
40+
uint64_t degrees[lwe_array_1->num_radix_blocks *
41+
lwe_array_1->num_radix_ciphertexts];
3942
if (mem_ptr->op == BITOP_TYPE::BITAND) {
40-
update_degrees_after_bitand(degrees, lwe_array_1->degrees,
41-
lwe_array_2->degrees,
42-
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
43+
update_degrees_after_bitand(
44+
degrees, lwe_array_1->degrees, lwe_array_2->degrees,
45+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
4346
} else if (mem_ptr->op == BITOP_TYPE::BITOR) {
44-
update_degrees_after_bitor(degrees, lwe_array_1->degrees,
45-
lwe_array_2->degrees,
46-
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
47+
update_degrees_after_bitor(
48+
degrees, lwe_array_1->degrees, lwe_array_2->degrees,
49+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
4750
} else if (mem_ptr->op == BITOP_TYPE::BITXOR) {
48-
update_degrees_after_bitxor(degrees, lwe_array_1->degrees,
49-
lwe_array_2->degrees,
50-
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
51+
update_degrees_after_bitxor(
52+
degrees, lwe_array_1->degrees, lwe_array_2->degrees,
53+
lwe_array_1->num_radix_blocks * lwe_array_1->num_radix_ciphertexts);
5154
}
5255

5356
integer_radix_apply_bivariate_lookup_table_kb<Torus>(
5457
streams, gpu_indexes, gpu_count, lwe_array_out, lwe_array_1, lwe_array_2,
55-
bsks, ksks, ms_noise_reduction_key, lut, lwe_array_out->num_radix_blocks * lwe_array_out->num_radix_ciphertexts,
58+
bsks, ksks, ms_noise_reduction_key, lut,
59+
lwe_array_out->num_radix_blocks * lwe_array_out->num_radix_ciphertexts,
5660
lut->params.message_modulus);
5761

5862
memcpy(lwe_array_out->degrees, degrees,
59-
lwe_array_out->num_radix_blocks * lwe_array_out->num_radix_ciphertexts * sizeof(uint64_t));
63+
lwe_array_out->num_radix_blocks *
64+
lwe_array_out->num_radix_ciphertexts * sizeof(uint64_t));
6065
}
6166

6267
template <typename Torus>

tfhe/src/high_level_api/array/gpu/integers.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ use crate::integer::block_decomposition::{
1919
DecomposableInto, RecomposableFrom, RecomposableSignedInteger,
2020
};
2121
use crate::integer::gpu::ciphertext::{
22-
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
22+
CudaIntegerRadixCiphertext, CudaRadixCiphertext, CudaSignedRadixCiphertext,
23+
CudaUnsignedRadixCiphertext,
2324
};
2425
use crate::integer::server_key::radix_parallel::scalar_div_mod::SignedReciprocable;
2526
use crate::integer::server_key::{Reciprocable, ScalarMultiplier};
@@ -83,6 +84,12 @@ impl<'a, T> TensorSlice<'a, GpuSlice<'a, T>> {
8384
pub fn par_iter(self) -> ParStridedIter<'a, T> {
8485
ParStridedIter::new(self.slice.0, self.dims.clone())
8586
}
87+
pub fn len(&self) -> usize {
88+
self.dims.flattened_len()
89+
}
90+
/// Returns the element slice wrapped by this tensor slice, exactly as
/// stored; the dimensions are not consulted.
pub fn as_slice(&self) -> &'a [T] {
    let elements: &'a [T] = self.slice.0;
    elements
}
8693
}
8794

8895
impl<'a, T> TensorSlice<'a, GpuSliceMut<'a, T>> {
@@ -316,7 +323,25 @@ where
316323
lhs: TensorSlice<'_, Self::Slice<'a>>,
317324
rhs: TensorSlice<'_, Self::Slice<'a>>,
318325
) -> Self::Owned {
319-
par_map_sks_op_on_pair_of_elements(lhs, rhs, crate::integer::gpu::CudaServerKey::bitand)
326+
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
327+
let streams = &cuda_key.streams;
328+
let num_ciphertexts = lhs.len() as u32;
329+
let lhs_slice: &[T] = lhs.as_slice();
330+
let rhs_slice: &[T] = rhs.as_slice();
331+
let mut lhs_aligned = T::from(CudaRadixCiphertext::from_radix_ciphertext_vec(
332+
lhs_slice, streams,
333+
));
334+
let rhs_aligned = T::from(CudaRadixCiphertext::from_radix_ciphertext_vec(
335+
rhs_slice, streams,
336+
));
337+
crate::integer::gpu::CudaServerKey::bitand_vec(
338+
cuda_key.pbs_key(),
339+
&mut lhs_aligned,
340+
&rhs_aligned,
341+
num_ciphertexts,
342+
streams,
343+
)
344+
}))
320345
}
321346

322347
fn bitor<'a>(

tfhe/src/high_level_api/array/traits.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ impl<'a, T> TensorSlice<'a, &'a [T]> {
2828
pub fn par_iter(self) -> ParStridedIter<'a, T> {
2929
ParStridedIter::new(self.slice, self.dims.clone())
3030
}
31+
pub fn len(&self) -> usize {
32+
self.dims.flattened_len()
33+
}
34+
/// Returns the element slice wrapped by this tensor slice, exactly as
/// stored; the dimensions are not consulted.
pub fn as_slice(&self) -> &'a [T] {
    let elements: &'a [T] = self.slice;
    elements
}
3137
}
3238

3339
impl<'a, T> TensorSlice<'a, &'a mut [T]> {

tfhe/src/integer/gpu/ciphertext/mod.rs

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ pub mod squashed_noise;
77

88
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
99
use crate::core_crypto::gpu::vec::CudaVec;
10-
use crate::core_crypto::gpu::CudaStreams;
10+
use crate::core_crypto::gpu::{CudaLweList, CudaStreams};
1111
use crate::core_crypto::prelude::{LweCiphertextList, LweCiphertextOwned};
1212
use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo};
1313
use crate::integer::parameters::LweDimension;
1414
use crate::integer::{IntegerCiphertext, RadixCiphertext, SignedRadixCiphertext};
1515
use crate::shortint::{Ciphertext, EncryptionKeyChoice};
1616
use crate::GpuIndex;
1717

18+
use crate::shortint::parameters::LweCiphertextCount;
1819
pub use compressed_noise_squashed_ciphertext_list::*;
1920

2021
pub trait CudaIntegerRadixCiphertext: Sized {
@@ -70,6 +71,60 @@ pub trait CudaIntegerRadixCiphertext: Sized {
7071
/// Returns the GPU indexes recorded on the device vector that backs this
/// ciphertext's blocks.
fn gpu_indexes(&self) -> &[GpuIndex] {
    let radix = self.as_ref();
    &radix.d_blocks.0.d_vec.gpu_indexes
}
74+
75+
fn to_integer_radix_ciphertext_vec(
76+
&self,
77+
num_radix_ciphertexts: u32,
78+
streams: &CudaStreams,
79+
) -> Vec<Self> {
80+
let total_blocks = self.as_ref().d_blocks.0.lwe_ciphertext_count.0;
81+
assert_eq!(total_blocks % num_radix_ciphertexts as usize, 0, "Total number of blocks ({total_blocks}) is not divisible by number of radix ciphertexts ({num_radix_ciphertexts})");
82+
83+
let num_blocks = total_blocks / num_radix_ciphertexts as usize;
84+
85+
let mut result = Vec::with_capacity(num_radix_ciphertexts as usize);
86+
let lwe_dimension = self.as_ref().d_blocks.lwe_dimension();
87+
88+
for i in 0..num_radix_ciphertexts as usize {
89+
let block_start = i * num_blocks;
90+
let block_end = block_start + num_blocks;
91+
92+
let d_vec = unsafe {
93+
let mut d_vec =
94+
CudaVec::new_async(lwe_dimension.to_lwe_size().0 * num_blocks, streams, 0);
95+
96+
let copy_start = block_start * lwe_dimension.to_lwe_size().0;
97+
let copy_end = (block_end + 1) * lwe_dimension.to_lwe_size().0;
98+
d_vec.copy_src_range_gpu_to_gpu_async(
99+
copy_start..copy_end,
100+
&self.as_ref().d_blocks.0.d_vec,
101+
streams,
102+
0,
103+
);
104+
105+
streams.synchronize();
106+
d_vec
107+
};
108+
let lwe_list = CudaLweList::<u64> {
109+
d_vec,
110+
lwe_ciphertext_count: LweCiphertextCount(num_blocks),
111+
lwe_dimension,
112+
ciphertext_modulus: self.as_ref().d_blocks.ciphertext_modulus(),
113+
};
114+
115+
// Copy the associated block metadata
116+
let block_info = self.as_ref().info.blocks[block_start..block_end].to_vec();
117+
118+
let info = CudaRadixCiphertextInfo { blocks: block_info };
119+
120+
result.push(Self::from(CudaRadixCiphertext::new(
121+
CudaLweCiphertextList(lwe_list),
122+
info,
123+
)));
124+
}
125+
126+
result
127+
}
73128
}
74129

75130
pub struct CudaRadixCiphertext {

tfhe/src/integer/gpu/mod.rs

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8134,3 +8134,150 @@ pub unsafe fn expand_async<T: UnsignedInteger, B: Numeric>(
81348134
std::ptr::addr_of_mut!(mem_ptr),
81358135
);
81368136
}
8137+
8138+
#[allow(clippy::too_many_arguments)]
8139+
/// # Safety
8140+
///
8141+
/// This operation modifies raw GPU pointers on the GPU
8142+
pub unsafe fn unchecked_bitop_vec_radix_kb_assign<T: UnsignedInteger, B: Numeric>(
8143+
streams: &CudaStreams,
8144+
radix_lwe_left: &mut CudaRadixCiphertext,
8145+
radix_lwe_right: &CudaRadixCiphertext,
8146+
bootstrapping_key: &CudaVec<B>,
8147+
keyswitch_key: &CudaVec<T>,
8148+
message_modulus: MessageModulus,
8149+
carry_modulus: CarryModulus,
8150+
glwe_dimension: GlweDimension,
8151+
polynomial_size: PolynomialSize,
8152+
big_lwe_dimension: LweDimension,
8153+
small_lwe_dimension: LweDimension,
8154+
ks_level: DecompositionLevelCount,
8155+
ks_base_log: DecompositionBaseLog,
8156+
pbs_level: DecompositionLevelCount,
8157+
pbs_base_log: DecompositionBaseLog,
8158+
op: BitOpType,
8159+
num_blocks: u32,
8160+
num_radix_ciphertexts: u32,
8161+
pbs_type: PBSType,
8162+
grouping_factor: LweBskGroupingFactor,
8163+
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
8164+
) {
8165+
assert_eq!(
8166+
streams.gpu_indexes[0],
8167+
radix_lwe_left.d_blocks.0.d_vec.gpu_index(0),
8168+
"GPU error: first stream is on GPU {}, first lhs pointer is on GPU {}",
8169+
streams.gpu_indexes[0].get(),
8170+
radix_lwe_left.d_blocks.0.d_vec.gpu_index(0).get(),
8171+
);
8172+
assert_eq!(
8173+
streams.gpu_indexes[0],
8174+
radix_lwe_right.d_blocks.0.d_vec.gpu_index(0),
8175+
"GPU error: first stream is on GPU {}, first rhs pointer is on GPU {}",
8176+
streams.gpu_indexes[0].get(),
8177+
radix_lwe_right.d_blocks.0.d_vec.gpu_index(0).get(),
8178+
);
8179+
assert_eq!(
8180+
streams.gpu_indexes[0],
8181+
bootstrapping_key.gpu_index(0),
8182+
"GPU error: first stream is on GPU {}, first bsk pointer is on GPU {}",
8183+
streams.gpu_indexes[0].get(),
8184+
bootstrapping_key.gpu_index(0).get(),
8185+
);
8186+
assert_eq!(
8187+
streams.gpu_indexes[0],
8188+
keyswitch_key.gpu_index(0),
8189+
"GPU error: first stream is on GPU {}, first ksk pointer is on GPU {}",
8190+
streams.gpu_indexes[0].get(),
8191+
keyswitch_key.gpu_index(0).get(),
8192+
);
8193+
let ct_modulus = radix_lwe_left
8194+
.d_blocks
8195+
.ciphertext_modulus()
8196+
.raw_modulus_float();
8197+
let (noise_reduction_type, ms_noise_reduction_key_ffi) =
8198+
resolve_ms_noise_reduction_config(ms_noise_reduction_configuration, ct_modulus);
8199+
8200+
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
8201+
let mut radix_lwe_left_degrees = radix_lwe_left
8202+
.info
8203+
.blocks
8204+
.iter()
8205+
.map(|b| b.degree.0)
8206+
.collect();
8207+
let mut radix_lwe_left_noise_levels = radix_lwe_left
8208+
.info
8209+
.blocks
8210+
.iter()
8211+
.map(|b| b.noise_level.0)
8212+
.collect();
8213+
let mut cuda_ffi_radix_lwe_left = prepare_cuda_radix_ffi(
8214+
radix_lwe_left,
8215+
&mut radix_lwe_left_degrees,
8216+
&mut radix_lwe_left_noise_levels,
8217+
);
8218+
// Here even though the input is not modified, data is passed as mutable.
8219+
// This avoids having to create two structs for the CudaRadixCiphertext pointers,
8220+
// one const and the other mutable.
8221+
// Having two structs on the Cuda side complicates things as we need to be sure we pass the
8222+
// Const structure as input instead of the mutable structure, which leads to complicated
8223+
// data manipulation on the C++ side to change mutability of data.
8224+
let mut radix_lwe_right_degrees = radix_lwe_right
8225+
.info
8226+
.blocks
8227+
.iter()
8228+
.map(|b| b.degree.0)
8229+
.collect();
8230+
let mut radix_lwe_right_noise_levels = radix_lwe_right
8231+
.info
8232+
.blocks
8233+
.iter()
8234+
.map(|b| b.noise_level.0)
8235+
.collect();
8236+
let cuda_ffi_radix_lwe_right = prepare_cuda_radix_ffi(
8237+
radix_lwe_right,
8238+
&mut radix_lwe_right_degrees,
8239+
&mut radix_lwe_right_noise_levels,
8240+
);
8241+
scratch_cuda_integer_radix_bitop_kb_64(
8242+
streams.ptr.as_ptr(),
8243+
streams.gpu_indexes_ptr(),
8244+
streams.len() as u32,
8245+
std::ptr::addr_of_mut!(mem_ptr),
8246+
glwe_dimension.0 as u32,
8247+
polynomial_size.0 as u32,
8248+
big_lwe_dimension.0 as u32,
8249+
small_lwe_dimension.0 as u32,
8250+
ks_level.0 as u32,
8251+
ks_base_log.0 as u32,
8252+
pbs_level.0 as u32,
8253+
pbs_base_log.0 as u32,
8254+
grouping_factor.0 as u32,
8255+
num_blocks * num_radix_ciphertexts,
8256+
message_modulus.0 as u32,
8257+
carry_modulus.0 as u32,
8258+
pbs_type as u32,
8259+
op as u32,
8260+
true,
8261+
noise_reduction_type as u32,
8262+
);
8263+
cuda_bitop_integer_radix_ciphertext_kb_64(
8264+
streams.ptr.as_ptr(),
8265+
streams.gpu_indexes_ptr(),
8266+
streams.len() as u32,
8267+
&raw mut cuda_ffi_radix_lwe_left,
8268+
&raw const cuda_ffi_radix_lwe_left,
8269+
&raw const cuda_ffi_radix_lwe_right,
8270+
mem_ptr,
8271+
bootstrapping_key.ptr.as_ptr(),
8272+
keyswitch_key.ptr.as_ptr(),
8273+
&raw const ms_noise_reduction_key_ffi,
8274+
);
8275+
cleanup_cuda_integer_bitop(
8276+
streams.ptr.as_ptr(),
8277+
streams.gpu_indexes_ptr(),
8278+
streams.len() as u32,
8279+
std::ptr::addr_of_mut!(mem_ptr),
8280+
);
8281+
update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left);
8282+
streams.synchronize();
8283+
}

0 commit comments

Comments
 (0)