Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 142 additions & 7 deletions iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
};
use eyre::Result;
use iris_mpc_common::{vector_id::VectorId, ROTATIONS};
use itertools::Itertools;
use itertools::{izip, Itertools};
use std::{collections::HashMap, fmt::Debug, sync::Arc, vec};
use tracing::instrument;

Expand Down Expand Up @@ -199,6 +199,49 @@
Ok(res[0].clone())
}

/// Obliviously computes the minimum distance for each batch of given distances of the same size.

Check warning on line 202 in iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs

View workflow job for this annotation

GitHub Actions / doc

unresolved link to `i`
/// The inner vector distances[i] contains the ith distances of each batch.
#[instrument(level = "trace", target = "searcher::network", skip_all, fields(batch_size = distances.len()))]
pub async fn oblivious_min_distance_batch(
&mut self,
distances: &[Vec<Aby3DistanceRef>],
) -> Result<Vec<Aby3DistanceRef>> {
if distances.is_empty() {
eyre::bail!("Cannot compute minimum of empty list");
}
let len = distances[0].len();
for (i, d) in distances.iter().enumerate() {
if d.len() != len {
eyre::bail!("All distance lists must have the same length. List at index {} has length {}, while the first list has length {}", i, d.len(), len);
}
}
let mut res = distances.to_vec();
while res.len() > 1 {
// if the length is odd, we save the last distance to add it back later
let maybe_last_distance = if res.len() % 2 == 1 { res.pop() } else { None };
let mut res1 = vec![];
std::mem::swap(&mut res1, &mut res);
let pairs: Vec<(_, _)> = res1
.into_iter()
.tuples()
.flat_map(|(a, b)| izip!(a, b).collect_vec())
.collect();
// compute minimums of pairs
let flattened_res = min_of_pair_batch(&mut self.session, &pairs).await?;
res = flattened_res
.into_iter()
.chunks(len)
.into_iter()
.map(|chunk| chunk.collect())
.collect_vec();
// if we saved a last distance, we need to add it back
if let Some(last_distance) = maybe_last_distance {
res.push(last_distance.clone());
}
}
Ok(res[0].clone())
}

#[instrument(level = "trace", target = "searcher::network", skip_all, fields(batch_size = vectors.len()))]
async fn eval_rotation_distances_batch(
&mut self,
Expand Down Expand Up @@ -304,12 +347,18 @@
}

let distances = self.eval_rotation_distances_batch(query, vectors).await?;
let mut results = Vec::with_capacity(vectors.len());
for rot_dists in distances.chunks(ROTATIONS) {
let min_dist = self.oblivious_min_distance(rot_dists).await?;
results.push(min_dist);
}
Ok(results)
let distance_per_rotation: Vec<Vec<Aby3DistanceRef>> = (0..ROTATIONS)
.map(|i| {
distances
.iter()
.skip(i)
.step_by(ROTATIONS)
.cloned()
.collect()
})
.collect();
self.oblivious_min_distance_batch(&distance_per_rotation)
.await
}

async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result<bool> {
Expand Down Expand Up @@ -778,6 +827,92 @@
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_oblivious_min_batch() -> Result<()> {
let list_len = 6_u32;
let num_lists = 3;
// create 3 lists of length 6
// [[(1,1), (2,1), (3,1), (4,1), (6,1), (5,1)],
// [(7,1), (8,1), (9,1), (12,1), (10,1), (11,1)],
// [(13,1), (14,1), (18,1), (15,1), (16,1), (17,1)]]
let mut flat_list = (1..=(list_len * num_lists)).map(|i| (i, 1)).collect_vec();
flat_list.swap(5, 4);
flat_list.swap(11, 9);
flat_list.swap(17, 14);
// [(1,1), (7,1), (13,1)],
// [(2,1), (8,1), (14,1)],
// [(3,1), (9,1), (18,1)],
// [(4,1), (12,1), (15,1)],
// [(6,1), (10,1), (16,1)],
// [(5,1), (11,1), (17,1)]
let mut plain_list = Vec::with_capacity(list_len as usize);
for i in 0..list_len {
let mut slice = Vec::with_capacity(num_lists as usize);
for j in 0..num_lists {
slice.push(flat_list[(i + list_len * j) as usize]);
}
plain_list.push(slice);
}

let mut local_stores = setup_local_store_aby3_players(NetworkType::Local).await?;
let mut jobs = JoinSet::new();
for store in local_stores.iter_mut() {
let store = store.clone();
let plain_list = plain_list.clone();
jobs.spawn(async move {
let mut store_lock = store.lock().await;
let role = store_lock.session.own_role();
let list = plain_list
.iter()
.map(|sub_list| {
sub_list
.iter()
.map(|(code_dist, mask_dist)| {
DistanceShare::new(
Share::from_const(*code_dist, role),
Share::from_const(*mask_dist, role),
)
})
.collect_vec()
})
.collect_vec();
store_lock.oblivious_min_distance_batch(&list).await
});
}
let res = jobs
.join_all()
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
let expected = flat_list
.chunks_exact(list_len as usize)
.map(|sublist| {
sublist
.iter()
.min_by(|a, b| (b.0 * a.1).cmp(&(a.0 * b.1)))
.unwrap()
})
.collect_vec();

for (i, exp) in expected.into_iter().enumerate() {
let distance = {
let code_dot =
(res[0][i].clone().code_dot + &res[1][i].code_dot + &res[2][i].code_dot)
.get_a()
.convert();
let mask_dot =
(res[0][i].clone().mask_dot + &res[1][i].mask_dot + &res[2][i].mask_dot)
.get_a()
.convert();
(code_dot, mask_dot)
};
assert_eq!(distance, *exp);
}

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_gr_aby3_store_plaintext_batch() -> Result<()> {
Expand Down
Loading