Skip to content

Commit 4f9aa76

Browse files
authored
Oblivious swap network (#1682)
Add oblivious swap network
1 parent 128211f commit 4f9aa76

File tree

4 files changed

+392
-8
lines changed

4 files changed

+392
-8
lines changed

iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs

Lines changed: 138 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@ use crate::{
33
hawkers::shared_irises::{SharedIrises, SharedIrisesRef},
44
hnsw::{vector_store::VectorStoreMut, VectorStore},
55
protocol::{
6-
ops::{batch_signed_lift_vec, cross_compare, galois_ring_to_rep3, lte_threshold_and_open},
6+
ops::{
7+
batch_signed_lift_vec, conditionally_swap_distances,
8+
conditionally_swap_distances_plain_ids, cross_compare, galois_ring_to_rep3,
9+
lte_threshold_and_open, oblivious_cross_compare,
10+
},
711
shared_iris::{ArcIris, GaloisRingSharedIris},
812
},
9-
shares::share::{DistanceShare, Share},
13+
shares::{
14+
bit::Bit,
15+
share::{DistanceShare, Share},
16+
},
1017
};
1118
use eyre::Result;
1219
use iris_mpc_common::vector_id::VectorId;
@@ -55,6 +62,7 @@ impl Aby3Query {
5562
}
5663

5764
pub type Aby3VectorRef = <Aby3Store as VectorStore>::VectorRef;
65+
pub type Aby3DistanceRef = <Aby3Store as VectorStore>::DistanceRef;
5866

5967
pub type Aby3SharedIrises = SharedIrises<ArcIris>;
6068
pub type Aby3SharedIrisesRef = SharedIrisesRef<ArcIris>;
@@ -122,6 +130,48 @@ impl Aby3Store {
122130
pub async fn checksum(&self) -> u64 {
123131
self.storage.checksum().await
124132
}
133+
134+
/// Obliviously swaps the elements in `list` at the given `indices` according to the `swap_bits`.
135+
/// If bit is 0, the elements are swapped, otherwise they are left unchanged.
136+
/// Note that unchanged elements of the list are propagated as secret-shares.
137+
pub async fn oblivious_swap_batch_plain_ids(
138+
&mut self,
139+
swap_bits: Vec<Share<Bit>>,
140+
list: &[(u32, Aby3DistanceRef)],
141+
indices: &[(usize, usize)],
142+
) -> Result<Vec<(Share<u32>, Aby3DistanceRef)>> {
143+
if list.is_empty() {
144+
return Ok(vec![]);
145+
}
146+
147+
conditionally_swap_distances_plain_ids(&mut self.session, swap_bits, list, indices).await
148+
}
149+
150+
/// Obliviously compares pairs of distances in batch and returns a secret shared bit a < b for each pair.
151+
pub async fn oblivious_less_than_batch(
152+
&mut self,
153+
distances: &[(Aby3DistanceRef, Aby3DistanceRef)],
154+
) -> Result<Vec<Share<Bit>>> {
155+
if distances.is_empty() {
156+
return Ok(vec![]);
157+
}
158+
oblivious_cross_compare(&mut self.session, distances).await
159+
}
160+
161+
/// Obliviously swaps the elements in `list` at the given `indices` according to the `swap_bits`.
162+
/// If bit is 0, the elements are swapped, otherwise they are left unchanged.
163+
pub async fn oblivious_swap_batch(
164+
&mut self,
165+
swap_bits: Vec<Share<Bit>>,
166+
list: &[(Share<u32>, Aby3DistanceRef)],
167+
indices: &[(usize, usize)],
168+
) -> Result<Vec<(Share<u32>, Aby3DistanceRef)>> {
169+
if list.is_empty() {
170+
return Ok(vec![]);
171+
}
172+
173+
conditionally_swap_distances(&mut self.session, swap_bits, list, indices).await
174+
}
125175
}
126176

127177
impl VectorStore for Aby3Store {
@@ -251,7 +301,7 @@ mod tests {
251301

252302
use super::*;
253303
use crate::{
254-
execution::hawk_main::scheduler::parallelize,
304+
execution::{hawk_main::scheduler::parallelize, session::SessionHandles},
255305
hawkers::{
256306
aby3::test_utils::{
257307
eval_vector_distance, get_owner_index, lazy_random_setup,
@@ -527,6 +577,91 @@ mod tests {
527577
Ok(())
528578
}
529579

580+
#[tokio::test(flavor = "multi_thread")]
581+
#[traced_test]
582+
async fn test_oblivious_swap() -> Result<()> {
583+
let list_len = 6_u32;
584+
let plain_list = (0..list_len)
585+
.map(|i| (VectorId::from_0_index(i), (i, i)))
586+
.collect_vec();
587+
let swap_bits_for_plain = vec![true, false];
588+
let indices_for_plain = vec![(0, 1), (4, 5)];
589+
let swap_bits_for_secret = vec![true, false, false];
590+
let indices_for_secret = vec![(1, 2), (0, 4), (3, 5)];
591+
592+
let mut local_stores = setup_local_store_aby3_players(NetworkType::Local).await?;
593+
let mut jobs = JoinSet::new();
594+
for store in local_stores.iter_mut() {
595+
let store = store.clone();
596+
let swap_bits_for_plain = swap_bits_for_plain.clone();
597+
let swap_bits_for_secret = swap_bits_for_secret.clone();
598+
let plain_list = plain_list.clone();
599+
let indices_for_plain = indices_for_plain.clone();
600+
let indices_for_secret = indices_for_secret.clone();
601+
jobs.spawn(async move {
602+
let mut store_lock = store.lock().await;
603+
let role = store_lock.session.own_role();
604+
let swap_bits1 = swap_bits_for_plain
605+
.iter()
606+
.map(|b| Share::from_const(Bit::new(*b), role))
607+
.collect_vec();
608+
let swap_bits2 = swap_bits_for_secret
609+
.iter()
610+
.map(|b| Share::from_const(Bit::new(*b), role))
611+
.collect_vec();
612+
let list = plain_list
613+
.iter()
614+
.map(|(v, d)| {
615+
(
616+
v.index(),
617+
DistanceShare::new(
618+
Share::from_const(d.0, role),
619+
Share::from_const(d.1, role),
620+
),
621+
)
622+
})
623+
.collect_vec();
624+
let tmp_list = store_lock
625+
.oblivious_swap_batch_plain_ids(swap_bits1, &list, &indices_for_plain)
626+
.await?;
627+
store_lock
628+
.oblivious_swap_batch(swap_bits2, &tmp_list, &indices_for_secret)
629+
.await
630+
});
631+
}
632+
let res = jobs
633+
.join_all()
634+
.await
635+
.into_iter()
636+
.collect::<Result<Vec<_>>>()?;
637+
let mut expected_list = plain_list.clone();
638+
expected_list.swap(4, 5);
639+
expected_list.swap(0, 4);
640+
expected_list.swap(3, 5);
641+
642+
for (i, exp) in expected_list.iter().enumerate() {
643+
let id = (res[0][i].clone().0 + &res[1][i].0 + &res[2][i].0)
644+
.get_a()
645+
.convert();
646+
assert_eq!(id, exp.0.index());
647+
648+
let distance = {
649+
let code_dot =
650+
(res[0][i].clone().1.code_dot + &res[1][i].1.code_dot + &res[2][i].1.code_dot)
651+
.get_a()
652+
.convert();
653+
let mask_dot =
654+
(res[0][i].clone().1.mask_dot + &res[1][i].1.mask_dot + &res[2][i].1.mask_dot)
655+
.get_a()
656+
.convert();
657+
(code_dot, mask_dot)
658+
};
659+
assert_eq!(distance, exp.1);
660+
}
661+
662+
Ok(())
663+
}
664+
530665
#[tokio::test(flavor = "multi_thread")]
531666
#[traced_test]
532667
async fn test_gr_aby3_store_plaintext_batch() -> Result<()> {

iris-mpc-cpu/src/hnsw/sorting/swap_network.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
use crate::hnsw::VectorStore;
1+
use crate::{
2+
hawkers::aby3::aby3_store::{Aby3DistanceRef, Aby3Store},
3+
hnsw::VectorStore,
4+
shares::Share,
5+
};
26
use eyre::{eyre, Result};
37
use itertools::Itertools;
48

@@ -142,3 +146,46 @@ pub async fn apply_swap_network<V: VectorStore>(
142146

143147
Ok(())
144148
}
149+
150+
/// Function obliviously applies the supplied swap network `network` to the list of
151+
/// tuples containing ids and distances between iris vectors as `(u32, Aby3DistanceRef)`.
152+
/// An 'Aby3Store' object executes comparisons via MPC for each layer in batches.
153+
/// Note that output is secret-shared, even for unchanged elements of the list.
154+
/// This implies that all vector ids are considered secret-shared after the first layer,
155+
/// which might introduce an additional throughput overhead.
156+
/// For example, for a swap network implementing the tournament method to find the minimum of a list of length N,
157+
/// this throughput overhead is O(1).
158+
pub async fn apply_oblivious_swap_network(
159+
store: &mut Aby3Store,
160+
list: &[(u32, Aby3DistanceRef)],
161+
network: &SwapNetwork,
162+
) -> Result<Vec<(Share<u32>, Aby3DistanceRef)>> {
163+
let mut encrypted_list = vec![];
164+
for (layer_id, layer) in network.layers.iter().enumerate() {
165+
let distances: Vec<_> = layer
166+
.iter()
167+
.filter_map(
168+
|(idx1, idx2): &(usize, usize)| match (list.get(*idx1), list.get(*idx2)) {
169+
// swap order to check for strict inequality d1 > d2
170+
(Some((_, d1)), Some((_, d2))) => Some((d2.clone(), d1.clone())),
171+
_ => None,
172+
},
173+
)
174+
.collect();
175+
// Computes d1 > d2 without opening as in less_than_batch
176+
let comp_results = store.oblivious_less_than_batch(&distances).await?;
177+
encrypted_list = if layer_id == 0 {
178+
// First layer: input ids are in plaintext, so we can use the more efficient plain_ids version.
179+
store
180+
.oblivious_swap_batch_plain_ids(comp_results, list, layer)
181+
.await?
182+
} else {
183+
// Following layers: input ids are secret shared
184+
store
185+
.oblivious_swap_batch(comp_results, &encrypted_list, layer)
186+
.await?
187+
};
188+
}
189+
190+
Ok(encrypted_list)
191+
}

0 commit comments

Comments
 (0)