Skip to content

Commit 78b8a44

Browse files
committed
feat(gpu): enables the user to perform computation on multi-gpu using a custom selection of GPUs
1 parent 7724b78 commit 78b8a44

21 files changed

+798
-623
lines changed

tfhe/src/core_crypto/gpu/mod.rs

+24-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ pub use entities::*;
1717
use std::ffi::c_void;
1818
use tfhe_cuda_backend::bindings::*;
1919
use tfhe_cuda_backend::cuda_bind::*;
20-
2120
pub struct CudaStreams {
2221
pub ptr: Vec<*mut c_void>,
2322
pub gpu_indexes: Vec<GpuIndex>,
@@ -43,6 +42,22 @@ impl CudaStreams {
4342
gpu_indexes,
4443
}
4544
}
45+
/// Create a new `CudaStreams` structure with the GPUs with id provided in a list
46+
pub fn new_multi_gpu_with_indexes(indexes: &[GpuIndex]) -> Self {
47+
let _gpu_count = setup_multi_gpu();
48+
49+
let mut gpu_indexes = Vec::with_capacity(indexes.len());
50+
let mut ptr_array = Vec::with_capacity(indexes.len());
51+
52+
for &i in indexes {
53+
ptr_array.push(unsafe { cuda_create_stream(i.get()) });
54+
gpu_indexes.push(i);
55+
}
56+
Self {
57+
ptr: ptr_array,
58+
gpu_indexes,
59+
}
60+
}
4661
/// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given
4762
/// as input
4863
pub fn new_single_gpu(gpu_index: GpuIndex) -> Self {
@@ -88,6 +103,14 @@ impl CudaStreams {
88103
}
89104
}
90105

106+
impl Clone for CudaStreams {
107+
fn clone(&self) -> Self {
108+
// The `new_multi_gpu_with_indexes()` function is used here to adapt to any specific type of
109+
// streams being cloned (single, multi, or custom)
110+
Self::new_multi_gpu_with_indexes(self.gpu_indexes.as_slice())
111+
}
112+
}
113+
91114
impl Drop for CudaStreams {
92115
fn drop(&mut self) {
93116
for (i, &s) in self.ptr.iter().enumerate() {

tfhe/src/high_level_api/array/gpu/booleans.rs

+46-59
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
156156
rhs: TensorSlice<'_, Self::Slice<'a>>,
157157
) -> Self::Owned {
158158
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
159-
with_thread_local_cuda_streams(|streams| {
160-
lhs.par_iter()
161-
.zip(rhs.par_iter())
162-
.map(|(lhs, rhs)| {
163-
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
164-
})
165-
.collect::<Vec<_>>()
166-
})
159+
let streams = &cuda_key.streams;
160+
lhs.par_iter()
161+
.zip(rhs.par_iter())
162+
.map(|(lhs, rhs)| {
163+
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
164+
})
165+
.collect::<Vec<_>>()
167166
}))
168167
}
169168

@@ -172,14 +171,13 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
172171
rhs: TensorSlice<'_, Self::Slice<'a>>,
173172
) -> Self::Owned {
174173
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
175-
with_thread_local_cuda_streams(|streams| {
176-
lhs.par_iter()
177-
.zip(rhs.par_iter())
178-
.map(|(lhs, rhs)| {
179-
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
180-
})
181-
.collect::<Vec<_>>()
182-
})
174+
let streams = &cuda_key.streams;
175+
lhs.par_iter()
176+
.zip(rhs.par_iter())
177+
.map(|(lhs, rhs)| {
178+
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
179+
})
180+
.collect::<Vec<_>>()
183181
}))
184182
}
185183

@@ -188,24 +186,22 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
188186
rhs: TensorSlice<'_, Self::Slice<'a>>,
189187
) -> Self::Owned {
190188
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
191-
with_thread_local_cuda_streams(|streams| {
192-
lhs.par_iter()
193-
.zip(rhs.par_iter())
194-
.map(|(lhs, rhs)| {
195-
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
196-
})
197-
.collect::<Vec<_>>()
198-
})
189+
let streams = &cuda_key.streams;
190+
lhs.par_iter()
191+
.zip(rhs.par_iter())
192+
.map(|(lhs, rhs)| {
193+
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
194+
})
195+
.collect::<Vec<_>>()
199196
}))
200197
}
201198

202199
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
203200
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
204-
with_thread_local_cuda_streams(|streams| {
205-
lhs.par_iter()
206-
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
207-
.collect::<Vec<_>>()
208-
})
201+
let streams = &cuda_key.streams;
202+
lhs.par_iter()
203+
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
204+
.collect::<Vec<_>>()
209205
}))
210206
}
211207
}
@@ -216,16 +212,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
216212
rhs: TensorSlice<'_, &'_ [bool]>,
217213
) -> Self::Owned {
218214
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
219-
with_thread_local_cuda_streams(|streams| {
220-
lhs.par_iter()
221-
.zip(rhs.par_iter().copied())
222-
.map(|(lhs, rhs)| {
223-
CudaBooleanBlock(
224-
cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams),
225-
)
226-
})
227-
.collect::<Vec<_>>()
228-
})
215+
let streams = &cuda_key.streams;
216+
lhs.par_iter()
217+
.zip(rhs.par_iter().copied())
218+
.map(|(lhs, rhs)| {
219+
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams))
220+
})
221+
.collect::<Vec<_>>()
229222
}))
230223
}
231224

@@ -234,16 +227,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
234227
rhs: TensorSlice<'_, &'_ [bool]>,
235228
) -> Self::Owned {
236229
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
237-
with_thread_local_cuda_streams(|streams| {
238-
lhs.par_iter()
239-
.zip(rhs.par_iter().copied())
240-
.map(|(lhs, rhs)| {
241-
CudaBooleanBlock(
242-
cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams),
243-
)
244-
})
245-
.collect::<Vec<_>>()
246-
})
230+
let streams = &cuda_key.streams;
231+
lhs.par_iter()
232+
.zip(rhs.par_iter().copied())
233+
.map(|(lhs, rhs)| {
234+
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams))
235+
})
236+
.collect::<Vec<_>>()
247237
}))
248238
}
249239

@@ -252,16 +242,13 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
252242
rhs: TensorSlice<'_, &'_ [bool]>,
253243
) -> Self::Owned {
254244
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
255-
with_thread_local_cuda_streams(|streams| {
256-
lhs.par_iter()
257-
.zip(rhs.par_iter().copied())
258-
.map(|(lhs, rhs)| {
259-
CudaBooleanBlock(
260-
cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams),
261-
)
262-
})
263-
.collect::<Vec<_>>()
264-
})
245+
let streams = &cuda_key.streams;
246+
lhs.par_iter()
247+
.zip(rhs.par_iter().copied())
248+
.map(|(lhs, rhs)| {
249+
CudaBooleanBlock(cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams))
250+
})
251+
.collect::<Vec<_>>()
265252
}))
266253
}
267254
}

tfhe/src/high_level_api/array/gpu/integers.rs

+14-17
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,11 @@ where
108108
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, &T, &CudaStreams) -> T,
109109
{
110110
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
111-
with_thread_local_cuda_streams(|streams| {
112-
lhs.par_iter()
113-
.zip(rhs.par_iter())
114-
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
115-
.collect::<Vec<_>>()
116-
})
111+
let streams = &cuda_key.streams;
112+
lhs.par_iter()
113+
.zip(rhs.par_iter())
114+
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
115+
.collect::<Vec<_>>()
117116
}))
118117
}
119118

@@ -170,12 +169,11 @@ where
170169
F: Send + Sync + Fn(&crate::integer::gpu::CudaServerKey, &T, Clear, &CudaStreams) -> T,
171170
{
172171
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
173-
with_thread_local_cuda_streams(|streams| {
174-
lhs.par_iter()
175-
.zip(rhs.par_iter())
176-
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
177-
.collect::<Vec<_>>()
178-
})
172+
let streams = &cuda_key.streams;
173+
lhs.par_iter()
174+
.zip(rhs.par_iter())
175+
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
176+
.collect::<Vec<_>>()
179177
}))
180178
}
181179

@@ -336,11 +334,10 @@ where
336334

337335
fn bitnot(lhs: TensorSlice<'_, Self::Slice<'_>>) -> Self::Owned {
338336
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
339-
with_thread_local_cuda_streams(|streams| {
340-
lhs.par_iter()
341-
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
342-
.collect::<Vec<_>>()
343-
})
337+
let streams = &cuda_key.streams;
338+
lhs.par_iter()
339+
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
340+
.collect::<Vec<_>>()
344341
}))
345342
}
346343
}

0 commit comments

Comments
 (0)