Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions iris-mpc-common/src/iris_db/iris.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,20 @@ impl IrisCode {
(code_distance as u16, combined_mask_len as u16)
}

/// Return the minimum distance of an iris code against all rotations of another iris code.
pub fn get_min_distance_fraction(&self, other: &Self) -> (u16, u16) {
let mut min_distance = (u16::MAX, u16::MAX);
for rotation in other.all_rotations() {
let distance = rotation.get_distance_fraction(self);
if distance.0 as u32 * (min_distance.1 as u32)
< distance.1 as u32 * min_distance.0 as u32
{
min_distance = distance;
}
}
min_distance
}

/// Return the fractional Hamming distance between two iris codes, represented
/// as the `i16` dot product of associated masked-bit vectors and the `u16` size
/// of the common unmasked region.
Expand Down
48 changes: 43 additions & 5 deletions iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
},
};
use eyre::Result;
use iris_mpc_common::vector_id::VectorId;
use iris_mpc_common::{vector_id::VectorId, ROTATIONS};
use itertools::Itertools;
use std::{collections::HashMap, fmt::Debug, sync::Arc, vec};
use tracing::instrument;
Expand Down Expand Up @@ -198,6 +198,25 @@ impl Aby3Store {
}
Ok(res[0].clone())
}

#[instrument(level = "trace", target = "searcher::network", skip_all, fields(batch_size = vectors.len()))]
async fn eval_rotation_distances_batch(
&mut self,
query: &Aby3Query,
vectors: &[Aby3VectorRef],
) -> Result<Vec<Aby3DistanceRef>> {
if vectors.is_empty() {
return Ok(vec![]);
}

let ds_and_ts = self
.workers
.rotation_aware_dot_product_batch(query.iris_proc.clone(), vectors.to_vec())
.await?;

let dist = galois_ring_to_rep3(&mut self.session, ds_and_ts).await?;
self.lift_distances(dist).await
}
}

impl VectorStore for Aby3Store {
Expand Down Expand Up @@ -274,6 +293,25 @@ impl VectorStore for Aby3Store {
self.lift_distances(dist).await
}

#[instrument(level = "trace", target = "searcher::network", skip_all, fields(batch_size = vectors.len()))]
async fn eval_minimal_rotation_distance_batch(
&mut self,
query: &Self::QueryRef,
vectors: &[Self::VectorRef],
) -> Result<Vec<Self::DistanceRef>> {
if vectors.is_empty() {
return Ok(vec![]);
}

let distances = self.eval_rotation_distances_batch(query, vectors).await?;
let mut results = Vec::with_capacity(vectors.len());
for rot_dists in distances.chunks(ROTATIONS) {
let min_dist = self.oblivious_min_distance(rot_dists).await?;
results.push(min_dist);
}
Ok(results)
}

async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result<bool> {
Ok(lte_threshold_and_open(&mut self.session, &[distance.clone()]).await?[0])
}
Expand Down Expand Up @@ -763,10 +801,10 @@ mod tests {

// compute distances in plaintext
let dist1_plain = plaintext_store
.eval_distance_batch(&Arc::new(plaintext_database[0].clone()), &plaintext_inserts)
.eval_minimal_rotation_distance_batch(&plaintext_preps[0], &plaintext_inserts)
.await?;
let dist2_plain = plaintext_store
.eval_distance_batch(&Arc::new(plaintext_database[1].clone()), &plaintext_inserts)
.eval_minimal_rotation_distance_batch(&plaintext_preps[1], &plaintext_inserts)
.await?;
let dist_plain = dist1_plain
.into_iter()
Expand Down Expand Up @@ -799,10 +837,10 @@ mod tests {
jobs.spawn(async move {
let mut store_lock = store.lock().await;
let dist1_aby3 = store_lock
.eval_distance_batch(&player_preps[0].clone(), &player_inserts)
.eval_minimal_rotation_distance_batch(&player_preps[0], &player_inserts)
.await?;
let dist2_aby3 = store_lock
.eval_distance_batch(&player_preps[1].clone(), &player_inserts)
.eval_minimal_rotation_distance_batch(&player_preps[1], &player_inserts)
.await?;
let dist_aby3 = dist1_aby3
.into_iter()
Expand Down
43 changes: 43 additions & 0 deletions iris-mpc-cpu/src/hawkers/plaintext_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,27 @@ impl VectorStore for PlaintextStore {
Ok(query.get_distance_fraction(vector_code))
}

async fn eval_minimal_rotation_distance_batch(
&mut self,
query: &Self::QueryRef,
vectors: &[Self::VectorRef],
) -> Result<Vec<Self::DistanceRef>> {
debug!(event_type = EvaluateDistance.id());
let vector_codes = vectors
.iter()
.map(|v| {
let serial_id = v.serial_id();
self.storage.get_vector(v).ok_or_else(|| {
eyre::eyre!("Vector ID not found in store for serial {}", serial_id)
})
})
.collect::<Result<Vec<_>>>()?;
Ok(vector_codes
.into_iter()
.map(|v| v.get_min_distance_fraction(query))
.collect())
}

async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result<bool> {
Ok(fraction_is_match(distance))
}
Expand Down Expand Up @@ -255,6 +276,28 @@ impl VectorStore for SharedPlaintextStore {
Ok(query.get_distance_fraction(vector_code))
}

async fn eval_minimal_rotation_distance_batch(
&mut self,
query: &Self::QueryRef,
vectors: &[Self::VectorRef],
) -> Result<Vec<Self::DistanceRef>> {
debug!(event_type = EvaluateDistance.id());
let store = self.storage.read().await;
let vector_codes = vectors
.iter()
.map(|v| {
let serial_id = v.serial_id();
store.get_vector(v).ok_or_else(|| {
eyre::eyre!("Vector ID not found in store for serial {}", serial_id)
})
})
.collect::<Result<Vec<_>>>()?;
Ok(vector_codes
.into_iter()
.map(|v| v.get_min_distance_fraction(query))
.collect())
}

async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result<bool> {
Ok(fraction_is_match(distance))
}
Expand Down
8 changes: 8 additions & 0 deletions iris-mpc-cpu/src/hnsw/graph/layered_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,14 @@ mod tests {
Ok(hamming_distance(vector_0, vector_1))
}

async fn eval_minimal_rotation_distance_batch(
&mut self,
_query: &Self::QueryRef,
_vectors: &[Self::VectorRef],
) -> Result<Vec<Self::DistanceRef>> {
unimplemented!()
}

async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result<bool> {
Ok(*distance == 0)
}
Expand Down
8 changes: 8 additions & 0 deletions iris-mpc-cpu/src/hnsw/vector_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ pub trait VectorStore: Debug {
Ok(results)
}

/// Evaluate the minimal distance over all distances between rotations of the query and a vector in the input batch.
/// TODO: replace eval_distance_batch with this method when API is stable.
async fn eval_minimal_rotation_distance_batch(
&mut self,
query: &Self::QueryRef,
vectors: &[Self::VectorRef],
) -> Result<Vec<Self::DistanceRef>>;

/// Check whether a batch of distances are matches.
/// The default implementation is a loop over `is_match`.
/// Override for more efficient batch match checks.
Expand Down
43 changes: 43 additions & 0 deletions iris-mpc-utils/src/graphs/plaintext/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,27 @@ impl VectorStore for PlaintextStore {
Ok(query.get_distance_fraction(vector_code))
}

async fn eval_minimal_rotation_distance_batch(
&mut self,
query: &Self::QueryRef,
vectors: &[Self::VectorRef],
) -> Result<Vec<Self::DistanceRef>> {
debug!(event_type = EvaluateDistance.id());
let vector_codes = vectors
.iter()
.map(|v| {
let serial_id = v.serial_id();
self.storage.get_vector(v).ok_or_else(|| {
eyre::eyre!("Vector ID not found in store for serial {}", serial_id)
})
})
.collect::<Result<Vec<_>>>()?;
Ok(vector_codes
.into_iter()
.map(|v| v.get_min_distance_fraction(query))
.collect())
}

async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result<bool> {
Ok(fraction_is_match(distance))
}
Expand Down Expand Up @@ -247,6 +268,28 @@ impl VectorStore for SharedPlaintextStore {
Ok(query.get_distance_fraction(vector_code))
}

async fn eval_minimal_rotation_distance_batch(
&mut self,
query: &Self::QueryRef,
vectors: &[Self::VectorRef],
) -> Result<Vec<Self::DistanceRef>> {
debug!(event_type = EvaluateDistance.id());
let store = self.storage.read().await;
let vector_codes = vectors
.iter()
.map(|v| {
let serial_id = v.serial_id();
store.get_vector(v).ok_or_else(|| {
eyre::eyre!("Vector ID not found in store for serial {}", serial_id)
})
})
.collect::<Result<Vec<_>>>()?;
Ok(vector_codes
.into_iter()
.map(|v| v.get_min_distance_fraction(query))
.collect())
}

async fn is_match(&mut self, distance: &Self::DistanceRef) -> Result<bool> {
Ok(fraction_is_match(distance))
}
Expand Down
Loading