diff --git a/iris-mpc-common/src/iris_db/iris.rs b/iris-mpc-common/src/iris_db/iris.rs index 277045f72..4026528a4 100644 --- a/iris-mpc-common/src/iris_db/iris.rs +++ b/iris-mpc-common/src/iris_db/iris.rs @@ -279,6 +279,20 @@ impl IrisCode { (code_distance as u16, combined_mask_len as u16) } + /// Return the minimum distance of an iris code against all rotations of another iris code. + pub fn get_min_distance_fraction(&self, other: &Self) -> (u16, u16) { + let mut min_distance = (u16::MAX, u16::MAX); + for rotation in other.all_rotations() { + let distance = rotation.get_distance_fraction(self); + if distance.0 as u32 * (min_distance.1 as u32) + < distance.1 as u32 * min_distance.0 as u32 + { + min_distance = distance; + } + } + min_distance + } + /// Return the fractional Hamming distance between two iris codes, represented /// as the `i16` dot product of associated masked-bit vectors and the `u16` size /// of the common unmasked region. diff --git a/iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs b/iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs index 5109067be..6740bdd19 100644 --- a/iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs +++ b/iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs @@ -16,7 +16,7 @@ use crate::{ }, }; use eyre::Result; -use iris_mpc_common::vector_id::VectorId; +use iris_mpc_common::{vector_id::VectorId, ROTATIONS}; use itertools::Itertools; use std::{collections::HashMap, fmt::Debug, sync::Arc, vec}; use tracing::instrument; @@ -198,6 +198,25 @@ impl Aby3Store { } Ok(res[0].clone()) } + + #[instrument(level = "trace", target = "searcher::network", skip_all, fields(batch_size = vectors.len()))] + async fn eval_rotation_distances_batch( + &mut self, + query: &Aby3Query, + vectors: &[Aby3VectorRef], + ) -> Result> { + if vectors.is_empty() { + return Ok(vec![]); + } + + let ds_and_ts = self + .workers + .rotation_aware_dot_product_batch(query.iris_proc.clone(), vectors.to_vec()) + .await?; + + let dist = galois_ring_to_rep3(&mut self.session, ds_and_ts).await?; + self.lift_distances(dist).await + } } impl VectorStore for Aby3Store { @@ -274,6 +293,25 @@ impl VectorStore for Aby3Store { self.lift_distances(dist).await } + #[instrument(level = "trace", target = "searcher::network", skip_all, fields(batch_size = vectors.len()))] + async fn eval_minimal_rotation_distance_batch( + &mut self, + query: &Self::QueryRef, + vectors: &[Self::VectorRef], + ) -> Result> { + if vectors.is_empty() { + return Ok(vec![]); + } + + let distances = self.eval_rotation_distances_batch(query, vectors).await?; + let mut results = Vec::with_capacity(vectors.len()); + for rot_dists in distances.chunks(ROTATIONS) { + let min_dist = self.oblivious_min_distance(rot_dists).await?; + results.push(min_dist); + } + Ok(results) + } + async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result { Ok(lte_threshold_and_open(&mut self.session, &[distance.clone()]).await?[0]) } @@ -763,10 +801,10 @@ mod tests { // compute distances in plaintext let dist1_plain = plaintext_store - .eval_distance_batch(&Arc::new(plaintext_database[0].clone()), &plaintext_inserts) + .eval_minimal_rotation_distance_batch(&plaintext_preps[0], &plaintext_inserts) .await?; let dist2_plain = plaintext_store - .eval_distance_batch(&Arc::new(plaintext_database[1].clone()), &plaintext_inserts) + .eval_minimal_rotation_distance_batch(&plaintext_preps[1], &plaintext_inserts) .await?; let dist_plain = dist1_plain .into_iter() @@ -799,10 +837,10 @@ mod tests { jobs.spawn(async move { let mut store_lock = store.lock().await; let dist1_aby3 = store_lock - .eval_distance_batch(&player_preps[0].clone(), &player_inserts) + .eval_minimal_rotation_distance_batch(&player_preps[0], &player_inserts) .await?; let dist2_aby3 = store_lock - .eval_distance_batch(&player_preps[1].clone(), &player_inserts) + .eval_minimal_rotation_distance_batch(&player_preps[1], &player_inserts) .await?; let dist_aby3 = dist1_aby3 .into_iter() diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 527cde810..79bb777c2 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -156,6 +156,27 @@ impl VectorStore for PlaintextStore { Ok(query.get_distance_fraction(vector_code)) } + async fn eval_minimal_rotation_distance_batch( + &mut self, + query: &Self::QueryRef, + vectors: &[Self::VectorRef], + ) -> Result> { + debug!(event_type = EvaluateDistance.id()); + let vector_codes = vectors + .iter() + .map(|v| { + let serial_id = v.serial_id(); + self.storage.get_vector(v).ok_or_else(|| { + eyre::eyre!("Vector ID not found in store for serial {}", serial_id) + }) + }) + .collect::>>()?; + Ok(vector_codes + .into_iter() + .map(|v| v.get_min_distance_fraction(query)) + .collect()) + } + async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result { Ok(fraction_is_match(distance)) } @@ -255,6 +276,28 @@ impl VectorStore for SharedPlaintextStore { Ok(query.get_distance_fraction(vector_code)) } + async fn eval_minimal_rotation_distance_batch( + &mut self, + query: &Self::QueryRef, + vectors: &[Self::VectorRef], + ) -> Result> { + debug!(event_type = EvaluateDistance.id()); + let store = self.storage.read().await; + let vector_codes = vectors + .iter() + .map(|v| { + let serial_id = v.serial_id(); + store.get_vector(v).ok_or_else(|| { + eyre::eyre!("Vector ID not found in store for serial {}", serial_id) + }) + }) + .collect::>>()?; + Ok(vector_codes + .into_iter() + .map(|v| v.get_min_distance_fraction(query)) + .collect()) + } + async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result { Ok(fraction_is_match(distance)) } diff --git a/iris-mpc-cpu/src/hnsw/graph/layered_graph.rs b/iris-mpc-cpu/src/hnsw/graph/layered_graph.rs index 9c12389ad..26fd00869 100644 --- a/iris-mpc-cpu/src/hnsw/graph/layered_graph.rs +++ b/iris-mpc-cpu/src/hnsw/graph/layered_graph.rs @@ -295,6 +295,14 @@ mod tests { Ok(hamming_distance(vector_0, vector_1)) } + async fn eval_minimal_rotation_distance_batch( + &mut self, + _query: &Self::QueryRef, + _vectors: &[Self::VectorRef], + ) -> Result> { + unimplemented!() + } + async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result { Ok(*distance == 0) } diff --git a/iris-mpc-cpu/src/hnsw/vector_store.rs b/iris-mpc-cpu/src/hnsw/vector_store.rs index dbf4a251a..1069dfcf0 100644 --- a/iris-mpc-cpu/src/hnsw/vector_store.rs +++ b/iris-mpc-cpu/src/hnsw/vector_store.rs @@ -97,6 +97,14 @@ pub trait VectorStore: Debug { Ok(results) } + /// Evaluate the minimal distance over all distances between rotations of the query and a vector in the input batch. + /// TODO: replace eval_distance_batch with this method when API is stable. + async fn eval_minimal_rotation_distance_batch( + &mut self, + query: &Self::QueryRef, + vectors: &[Self::VectorRef], + ) -> Result>; + /// Check whether a batch of distances are matches. /// The default implementation is a loop over `is_match`. /// Override for more efficient batch match checks. diff --git a/iris-mpc-utils/src/graphs/plaintext/store.rs b/iris-mpc-utils/src/graphs/plaintext/store.rs index 405fcfb7f..d94835611 100644 --- a/iris-mpc-utils/src/graphs/plaintext/store.rs +++ b/iris-mpc-utils/src/graphs/plaintext/store.rs @@ -148,6 +148,27 @@ impl VectorStore for PlaintextStore { Ok(query.get_distance_fraction(vector_code)) } + async fn eval_minimal_rotation_distance_batch( + &mut self, + query: &Self::QueryRef, + vectors: &[Self::VectorRef], + ) -> Result> { + debug!(event_type = EvaluateDistance.id()); + let vector_codes = vectors + .iter() + .map(|v| { + let serial_id = v.serial_id(); + self.storage.get_vector(v).ok_or_else(|| { + eyre::eyre!("Vector ID not found in store for serial {}", serial_id) + }) + }) + .collect::>>()?; + Ok(vector_codes + .into_iter() + .map(|v| v.get_min_distance_fraction(query)) + .collect()) + } + async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result { Ok(fraction_is_match(distance)) } @@ -247,6 +268,28 @@ impl VectorStore for SharedPlaintextStore { Ok(query.get_distance_fraction(vector_code)) } + async fn eval_minimal_rotation_distance_batch( + &mut self, + query: &Self::QueryRef, + vectors: &[Self::VectorRef], + ) -> Result> { + debug!(event_type = EvaluateDistance.id()); + let store = self.storage.read().await; + let vector_codes = vectors + .iter() + .map(|v| { + let serial_id = v.serial_id(); + store.get_vector(v).ok_or_else(|| { + eyre::eyre!("Vector ID not found in store for serial {}", serial_id) + }) + }) + .collect::>>()?; + Ok(vector_codes + .into_iter() + .map(|v| v.get_min_distance_fraction(query)) + .collect()) + } + async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result { Ok(fraction_is_match(distance)) }