Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 125 additions & 1 deletion iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
};
use eyre::Result;
use iris_mpc_common::vector_id::VectorId;
use itertools::Itertools;
use itertools::{izip, Itertools};
use std::{collections::HashMap, fmt::Debug, sync::Arc, vec};
use tracing::instrument;

Expand Down Expand Up @@ -198,6 +198,44 @@
}
Ok(res[0].clone())
}

/// Obliviously computes the minimum distance for each batch of given distances of the same size.
/// The inner vector distances[i] contains the ith distances of each batch.

Check warning on line 203 in iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs

View workflow job for this annotation

GitHub Actions / doc

unresolved link to `i`
pub async fn oblivious_min_distance_batch(
&mut self,
distances: &[Vec<Aby3DistanceRef>],
) -> Result<Vec<Aby3DistanceRef>> {
if distances.is_empty() {
eyre::bail!("Cannot compute minimum of empty list");
}
let len = distances[0].len();
for (i, d) in distances.iter().enumerate() {
if d.len() != len {
eyre::bail!("All distance lists must have the same length. List at index {} has length {}, while the first list has length {}", i, d.len(), len);
}
}
let mut res = distances.to_vec();
while res.len() > 1 {
// if the length is odd, we save the last distance to add it back later
let maybe_last_distance = if res.len() % 2 == 1 { res.pop() } else { None };
// create pairs from the remaining distances
let pairs = res
.chunks_exact(2)
.flat_map(|chunk| izip!(chunk[0].clone(), chunk[1].clone()).collect_vec())
.collect_vec();
// compute minimums of pairs
let flattened_res = min_of_pair_batch(&mut self.session, &pairs).await?;
res = flattened_res
.chunks_exact(len)
.map(|chunk| chunk.to_vec())
.collect_vec();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let pairs = res
.chunks_exact(2)
.flat_map(|chunk| izip!(chunk[0].clone(), chunk[1].clone()).collect_vec())
.collect_vec();
// compute minimums of pairs
let flattened_res = min_of_pair_batch(&mut self.session, &pairs).await?;
res = flattened_res
.chunks_exact(len)
.map(|chunk| chunk.to_vec())
.collect_vec();
let mut res1 = vec![];
std::mem::swap(&mut res1, &mut res);
let pairs: Vec<(_, _)> = res1
.into_iter()
.tuples()
.flat_map(|(a, b)| izip!(a, b).collect_vec())
.collect();
// compute minimums of pairs
let flattened_res = min_of_pair_batch(&mut self.session, &pairs).await?;
res = flattened_res
.into_iter()
.chunks(len)
.into_iter()
.map(|chunk| chunk.collect())
.collect_vec();

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks!

// if we saved a last distance, we need to add it back
if let Some(last_distance) = maybe_last_distance {
res.push(last_distance.clone());
}
}
Ok(res[0].clone())
}
}

impl VectorStore for Aby3Store {
Expand Down Expand Up @@ -740,6 +778,92 @@
Ok(())
}

#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_oblivious_min_batch() -> Result<()> {
let list_len = 6_u32;
let num_lists = 3;
// create 3 lists of length 6
// [[(1,1), (2,1), (3,1), (4,1), (6,1), (5,1)],
// [(7,1), (8,1), (9,1), (12,1), (10,1), (11,1)],
// [(13,1), (14,1), (18,1), (15,1), (16,1), (17,1)]]
let mut flat_list = (1..=(list_len * num_lists)).map(|i| (i, 1)).collect_vec();
flat_list.swap(5, 4);
flat_list.swap(11, 9);
flat_list.swap(17, 14);
// [(1,1), (7,1), (13,1)],
// [(2,1), (8,1), (14,1)],
// [(3,1), (9,1), (18,1)],
// [(4,1), (12,1), (15,1)],
// [(6,1), (10,1), (16,1)],
// [(5,1), (11,1), (17,1)]
let mut plain_list = Vec::with_capacity(list_len as usize);
for i in 0..list_len {
let mut slice = Vec::with_capacity(num_lists as usize);
for j in 0..num_lists {
slice.push(flat_list[(i + list_len * j) as usize]);
}
plain_list.push(slice);
}

let mut local_stores = setup_local_store_aby3_players(NetworkType::Local).await?;
let mut jobs = JoinSet::new();
for store in local_stores.iter_mut() {
let store = store.clone();
let plain_list = plain_list.clone();
jobs.spawn(async move {
let mut store_lock = store.lock().await;
let role = store_lock.session.own_role();
let list = plain_list
.iter()
.map(|sub_list| {
sub_list
.iter()
.map(|(code_dist, mask_dist)| {
DistanceShare::new(
Share::from_const(*code_dist, role),
Share::from_const(*mask_dist, role),
)
})
.collect_vec()
})
.collect_vec();
store_lock.oblivious_min_distance_batch(&list).await
});
}
let res = jobs
.join_all()
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
let expected = flat_list
.chunks_exact(list_len as usize)
.map(|sublist| {
sublist
.iter()
.min_by(|a, b| (b.0 * a.1).cmp(&(a.0 * b.1)))
.unwrap()
})
.collect_vec();

for (i, exp) in expected.into_iter().enumerate() {
let distance = {
let code_dot =
(res[0][i].clone().code_dot + &res[1][i].code_dot + &res[2][i].code_dot)
.get_a()
.convert();
let mask_dot =
(res[0][i].clone().mask_dot + &res[1][i].mask_dot + &res[2][i].mask_dot)
.get_a()
.convert();
(code_dot, mask_dot)
};
assert_eq!(distance, *exp);
}

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
#[traced_test]
async fn test_gr_aby3_store_plaintext_batch() -> Result<()> {
Expand Down
Loading