From 3021b5135acaba0bedf27d116151b030ec127adc Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Fri, 19 Sep 2025 16:32:37 +0200 Subject: [PATCH 01/16] wip: anon stats separately for reauth and uniqueness --- iris-mpc-common/src/config/mod.rs | 11 + iris-mpc-common/src/helpers/statistics.rs | 49 +++ iris-mpc-common/src/job.rs | 4 + iris-mpc-common/tests/statistics.rs | 3 + iris-mpc-cpu/src/execution/hawk_main.rs | 3 + iris-mpc-gpu/src/server/actor.rs | 455 +++++++++++++++++----- iris-mpc-gpu/tests/e2e-anon-stats.rs | 4 + iris-mpc-gpu/tests/e2e.rs | 4 + iris-mpc/bin/server.rs | 69 ++++ iris-mpc/src/services/processors/job.rs | 31 ++ 10 files changed, 545 insertions(+), 88 deletions(-) diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index ce312ddda..d46405ab4 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -191,6 +191,10 @@ pub struct Config { #[serde(default = "default_match_distances_2d_buffer_size")] pub match_distances_2d_buffer_size: usize, + /// Minimum number of reauth match distances required before publishing 1D anonymized stats + #[serde(default = "default_reauth_match_distances_min_count")] + pub reauth_match_distances_min_count: usize, + #[serde(default = "default_n_buckets")] pub n_buckets: usize, @@ -384,6 +388,10 @@ fn default_match_distances_2d_buffer_size() -> usize { 1 << 13 // 8192 } +fn default_reauth_match_distances_min_count() -> usize { + 10_000 +} + fn default_n_buckets() -> usize { 375 } @@ -697,6 +705,7 @@ pub struct CommonConfig { match_distances_buffer_size: usize, match_distances_buffer_size_extra_percent: usize, match_distances_2d_buffer_size: usize, + reauth_match_distances_min_count: usize, n_buckets: usize, enable_sending_anonymized_stats_message: bool, enable_sending_mirror_anonymized_stats_message: bool, @@ -778,6 +787,7 @@ impl From for CommonConfig { match_distances_buffer_size, match_distances_buffer_size_extra_percent, match_distances_2d_buffer_size, 
+ reauth_match_distances_min_count, n_buckets, enable_sending_anonymized_stats_message, enable_sending_mirror_anonymized_stats_message, @@ -847,6 +857,7 @@ impl From for CommonConfig { match_distances_buffer_size, match_distances_buffer_size_extra_percent, match_distances_2d_buffer_size, + reauth_match_distances_min_count, n_buckets, enable_sending_anonymized_stats_message, enable_sending_mirror_anonymized_stats_message, diff --git a/iris-mpc-common/src/helpers/statistics.rs b/iris-mpc-common/src/helpers/statistics.rs index 6bc454941..0c258a779 100644 --- a/iris-mpc-common/src/helpers/statistics.rs +++ b/iris-mpc-common/src/helpers/statistics.rs @@ -6,6 +6,14 @@ use chrono::{ use serde::{Deserialize, Serialize}; use std::fmt; +// Operation of the anonymized statistics producer +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +pub enum Operation { + #[default] + Uniqueness, + Reauth, +} + // 1D anonymized statistics types #[derive(Debug, Clone, Serialize, Deserialize)] pub struct BucketResult { @@ -30,6 +38,8 @@ pub struct BucketStatistics { pub match_distances_buffer_size: usize, pub party_id: usize, pub eye: Eye, + // Operation type this histogram belongs to + pub operation: Operation, #[serde(with = "ts_seconds")] // Start timestamp at which we start recording the statistics pub start_time_utc_timestamp: DateTime, @@ -54,6 +64,7 @@ impl fmt::Display for BucketStatistics { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, " party_id: {}", self.party_id)?; writeln!(f, " eye: {:?}", self.eye)?; + writeln!(f, " operation: {:?}", self.operation)?; writeln!(f, " start_time_utc: {}", self.start_time_utc_timestamp)?; match &self.end_time_utc_timestamp { Some(end) => writeln!(f, " end_time_utc: {}", end)?, @@ -84,6 +95,7 @@ impl BucketStatistics { eye, match_distances_buffer_size, party_id, + operation: Operation::Uniqueness, start_time_utc_timestamp: Utc::now(), end_time_utc_timestamp: None, 
next_start_time_utc_timestamp: None, @@ -91,6 +103,24 @@ impl BucketStatistics { } } + /// Create a new `BucketStatistics` with explicit operation type. + pub fn new_with_operation( + match_distances_buffer_size: usize, + n_buckets: usize, + party_id: usize, + eye: Eye, + operation: Operation, + ) -> Self { + let mut bs = Self::new( + match_distances_buffer_size, + n_buckets, + party_id, + eye, + ); + bs.operation = operation; + bs + } + /// `buckets_array` array of buckets /// `buckets`, which for i=0..n_buckets might be a cumulative count (or /// partial sum). @@ -181,6 +211,8 @@ pub struct BucketStatistics2D { // The number of two-sided matches gathered before sending the statistics pub match_distances_buffer_size: usize, pub party_id: usize, + // Operation type this histogram belongs to + pub operation: Operation, #[serde(with = "ts_seconds")] pub start_time_utc_timestamp: DateTime, #[serde(with = "ts_seconds_option")] @@ -200,6 +232,7 @@ impl BucketStatistics2D { impl fmt::Display for BucketStatistics2D { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, " party_id: {}", self.party_id)?; + writeln!(f, " operation: {:?}", self.operation)?; writeln!(f, " start_time_utc: {}", self.start_time_utc_timestamp)?; match &self.end_time_utc_timestamp { Some(end) => writeln!(f, " end_time_utc: {}", end)?, @@ -231,12 +264,28 @@ impl BucketStatistics2D { n_buckets_per_side, match_distances_buffer_size, party_id, + operation: Operation::Uniqueness, start_time_utc_timestamp: Utc::now(), end_time_utc_timestamp: None, next_start_time_utc_timestamp: None, } } + pub fn new_with_operation( + match_distances_buffer_size: usize, + n_buckets_per_side: usize, + party_id: usize, + operation: Operation, + ) -> Self { + let mut bs = Self::new( + match_distances_buffer_size, + n_buckets_per_side, + party_id, + ); + bs.operation = operation; + bs + } + // Fill bucket counts for the 2D histogram. 
// buckets_2d is expected in row-major order (left index major): // buckets_2d[left_idx * n_buckets_per_side + right_idx] diff --git a/iris-mpc-common/src/job.rs b/iris-mpc-common/src/job.rs index 96a711fd4..0e47db2e9 100644 --- a/iris-mpc-common/src/job.rs +++ b/iris-mpc-common/src/job.rs @@ -410,6 +410,10 @@ pub struct ServerJobResult { // 2D anonymized statistics across both eyes (only for matches on both sides) // Only for Normal orientation pub anonymized_bucket_statistics_2d: BucketStatistics2D, + // Reauth-only anonymized stats (Normal orientation only) + pub anonymized_bucket_statistics_left_reauth: BucketStatistics, + pub anonymized_bucket_statistics_right_reauth: BucketStatistics, + pub anonymized_bucket_statistics_2d_reauth: BucketStatistics2D, // Mirror orientation bucket statistics pub anonymized_bucket_statistics_left_mirror: BucketStatistics, pub anonymized_bucket_statistics_right_mirror: BucketStatistics, diff --git a/iris-mpc-common/tests/statistics.rs b/iris-mpc-common/tests/statistics.rs index f6be24fda..a511a9a9e 100644 --- a/iris-mpc-common/tests/statistics.rs +++ b/iris-mpc-common/tests/statistics.rs @@ -1,5 +1,6 @@ mod tests { use chrono::{TimeZone, Utc}; + use iris_mpc_common::helpers::statistics::Operation; use iris_mpc_common::{ helpers::statistics::{Bucket2DResult, BucketResult, BucketStatistics, BucketStatistics2D}, job::Eye, @@ -14,6 +15,7 @@ mod tests { // Create a struct with some data let stats = BucketStatistics { + operation: Operation::Uniqueness, buckets: vec![ BucketResult { count: 10, @@ -133,6 +135,7 @@ mod tests { hamming_distance_bucket: [0.33, 0.66], }], n_buckets: 1, + operation: Operation::Uniqueness, match_distances_buffer_size: 42, party_id: 777, eye: Eye::Right, diff --git a/iris-mpc-cpu/src/execution/hawk_main.rs b/iris-mpc-cpu/src/execution/hawk_main.rs index 2ee021c03..2be1eb952 100644 --- a/iris-mpc-cpu/src/execution/hawk_main.rs +++ b/iris-mpc-cpu/src/execution/hawk_main.rs @@ -1237,9 +1237,12 @@ impl HawkResult { 
matched_batch_request_ids, anonymized_bucket_statistics_left, anonymized_bucket_statistics_right, + anonymized_bucket_statistics_left_reauth: BucketStatistics::default(), + anonymized_bucket_statistics_right_reauth: BucketStatistics::default(), anonymized_bucket_statistics_left_mirror: BucketStatistics::default(), // TODO. anonymized_bucket_statistics_right_mirror: BucketStatistics::default(), // TODO. anonymized_bucket_statistics_2d: BucketStatistics2D::default(), // TODO. + anonymized_bucket_statistics_2d_reauth: BucketStatistics2D::default(), successful_reauths, reauth_target_indices: batch.reauth_target_indices, diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 459a79d80..61cb5ad98 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -38,7 +38,7 @@ use iris_mpc_common::{ inmemory_store::InMemoryStore, sha256::sha256_bytes, smpc_request::{REAUTH_MESSAGE_TYPE, RESET_CHECK_MESSAGE_TYPE, UNIQUENESS_MESSAGE_TYPE}, - statistics::{BucketStatistics, BucketStatistics2D}, + statistics::{BucketStatistics, BucketStatistics2D, Operation}, }, iris_db::{get_dummy_shares_for_deletion, iris::MATCH_THRESHOLD_RATIO}, job::{Eye, JobSubmissionHandle, ServerJobResult}, @@ -111,6 +111,13 @@ pub enum Orientation { Mirror, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum RequestOp { + Uniqueness, + Reauth, + Other, +} + impl fmt::Display for Orientation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -174,13 +181,21 @@ pub struct ServerActor { buckets: ChunkShare, anonymized_bucket_statistics_left: BucketStatistics, anonymized_bucket_statistics_right: BucketStatistics, + anonymized_bucket_statistics_left_reauth: BucketStatistics, + anonymized_bucket_statistics_right_reauth: BucketStatistics, anonymized_bucket_statistics_left_mirror: BucketStatistics, anonymized_bucket_statistics_right_mirror: BucketStatistics, - // 2D anon stats buffer - both_side_match_distances_buffer: Vec, + // 
2D anon stats buffers per operation + both_side_match_distances_buffer_uni: Vec, + both_side_match_distances_buffer_reauth: Vec, anonymized_bucket_statistics_2d: BucketStatistics2D, + anonymized_bucket_statistics_2d_reauth: BucketStatistics2D, full_scan_side: Eye, full_scan_side_switching_enabled: bool, + // per-op 1D publish threshold for reauth (number of match distances across all devices) + reauth_match_distances_min_count: usize, + // Mapping from batch_id -> per-query op (after filtering invalid entries) + batch_ops_map: HashMap>, } const NON_MATCH_ID: u32 = u32::MAX; @@ -197,6 +212,7 @@ impl ServerActor { match_distances_buffer_size_extra_percent: usize, match_distances_2d_buffer_size: usize, n_buckets: usize, + reauth_match_distances_min_count: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -216,6 +232,7 @@ impl ServerActor { match_distances_buffer_size_extra_percent, match_distances_2d_buffer_size, n_buckets, + reauth_match_distances_min_count, return_partial_results, disable_persistence, enable_debug_timing, @@ -235,6 +252,7 @@ impl ServerActor { match_distances_buffer_size_extra_percent: usize, match_distances_2d_buffer_size: usize, n_buckets: usize, + reauth_match_distances_min_count: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -256,6 +274,7 @@ impl ServerActor { match_distances_buffer_size_extra_percent, match_distances_2d_buffer_size, n_buckets, + reauth_match_distances_min_count, return_partial_results, disable_persistence, enable_debug_timing, @@ -277,6 +296,7 @@ impl ServerActor { match_distances_buffer_size_extra_percent: usize, match_distances_2d_buffer_size: usize, n_buckets: usize, + reauth_match_distances_min_count: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -296,6 +316,7 @@ impl ServerActor { match_distances_buffer_size_extra_percent, match_distances_2d_buffer_size, n_buckets, + 
reauth_match_distances_min_count, return_partial_results, disable_persistence, enable_debug_timing, @@ -318,6 +339,7 @@ impl ServerActor { match_distances_buffer_size_extra_percent: usize, match_distances_2d_buffer_size: usize, n_buckets: usize, + reauth_match_distances_min_count: usize, return_partial_results: bool, disable_persistence: bool, enable_debug_timing: bool, @@ -506,26 +528,73 @@ impl ServerActor { dev.synchronize().unwrap(); } - let anonymized_bucket_statistics_left = - BucketStatistics::new(match_distances_buffer_size, n_buckets, party_id, Eye::Left); + let anonymized_bucket_statistics_left = BucketStatistics::new_with_operation( + match_distances_buffer_size, + n_buckets, + party_id, + Eye::Left, + Operation::Uniqueness, + ); - let anonymized_bucket_statistics_right = - BucketStatistics::new(match_distances_buffer_size, n_buckets, party_id, Eye::Right); + let anonymized_bucket_statistics_right = BucketStatistics::new_with_operation( + match_distances_buffer_size, + n_buckets, + party_id, + Eye::Right, + Operation::Uniqueness, + ); - let mut anonymized_bucket_statistics_left_mirror = - BucketStatistics::new(match_distances_buffer_size, n_buckets, party_id, Eye::Left); + let anonymized_bucket_statistics_left_reauth = BucketStatistics::new_with_operation( + match_distances_buffer_size, + n_buckets, + party_id, + Eye::Left, + Operation::Reauth, + ); + let anonymized_bucket_statistics_right_reauth = BucketStatistics::new_with_operation( + match_distances_buffer_size, + n_buckets, + party_id, + Eye::Right, + Operation::Reauth, + ); + + let mut anonymized_bucket_statistics_left_mirror = BucketStatistics::new_with_operation( + match_distances_buffer_size, + n_buckets, + party_id, + Eye::Left, + Operation::Uniqueness, + ); anonymized_bucket_statistics_left_mirror.is_mirror_orientation = true; - let mut anonymized_bucket_statistics_right_mirror = - BucketStatistics::new(match_distances_buffer_size, n_buckets, party_id, Eye::Right); + let mut 
anonymized_bucket_statistics_right_mirror = BucketStatistics::new_with_operation( + match_distances_buffer_size, + n_buckets, + party_id, + Eye::Right, + Operation::Uniqueness, + ); anonymized_bucket_statistics_right_mirror.is_mirror_orientation = true; tracing::info!("GPU actor: Initialized"); - let both_side_match_distances_buffer = + let both_side_match_distances_buffer_uni = + vec![TwoSidedDistanceCache::default(); device_manager.device_count()]; + let both_side_match_distances_buffer_reauth = vec![TwoSidedDistanceCache::default(); device_manager.device_count()]; - let anonymized_bucket_statistics_2d = - BucketStatistics2D::new(match_distances_2d_buffer_size, n_buckets, party_id); + let anonymized_bucket_statistics_2d = BucketStatistics2D::new_with_operation( + match_distances_2d_buffer_size, + n_buckets, + party_id, + Operation::Uniqueness, + ); + let anonymized_bucket_statistics_2d_reauth = BucketStatistics2D::new_with_operation( + match_distances_2d_buffer_size, + n_buckets, + party_id, + Operation::Reauth, + ); Ok(Self { party_id, @@ -575,12 +644,18 @@ impl ServerActor { match_distances_buffer_mirror: bucket_distance_cache_mirror, anonymized_bucket_statistics_left, anonymized_bucket_statistics_right, + anonymized_bucket_statistics_left_reauth, + anonymized_bucket_statistics_right_reauth, anonymized_bucket_statistics_left_mirror, anonymized_bucket_statistics_right_mirror, full_scan_side, full_scan_side_switching_enabled, - both_side_match_distances_buffer, + both_side_match_distances_buffer_uni, + both_side_match_distances_buffer_reauth, anonymized_bucket_statistics_2d, + anonymized_bucket_statistics_2d_reauth, + reauth_match_distances_min_count, + batch_ops_map: HashMap::new(), }) } @@ -639,6 +714,13 @@ impl ServerActor { .buckets .clear(); self.anonymized_bucket_statistics_2d.buckets.clear(); + self.anonymized_bucket_statistics_left_reauth + .buckets + .clear(); + self.anonymized_bucket_statistics_right_reauth + .buckets + .clear(); + 
self.anonymized_bucket_statistics_2d_reauth.buckets.clear(); tracing::info!( "Full batch duration took: {:?}", @@ -821,6 +903,20 @@ impl ServerActor { batch_size = valid_entry_idxs.len(); batch.retain(&valid_entry_idxs); tracing::info!("Sync and filter done in {:?}", tmp_now.elapsed()); + + // Record operation type per query for this batch id (used for anon stats splitting) + let next_batch_id = self.internal_batch_counter + 1; + let ops: Vec = batch + .request_types + .iter() + .map(|t| match t.as_str() { + UNIQUENESS_MESSAGE_TYPE => RequestOp::Uniqueness, + REAUTH_MESSAGE_TYPE => RequestOp::Reauth, + _ => RequestOp::Other, + }) + .collect(); + self.batch_ops_map.insert(next_batch_id, ops); + self.internal_batch_counter += 1; /////////////////////////////////////////////////////////////////// @@ -1623,34 +1719,58 @@ impl ServerActor { .map(|(left, right)| TwoSidedDistanceCache::merge(left, right)) .collect::>(); - for (new, cache) in two_sided_match_distances - .into_iter() - .zip(self.both_side_match_distances_buffer.iter_mut()) - { - cache.extend(new); + // Partition by operation and extend the appropriate caches + for (dev_idx, new) in two_sided_match_distances.into_iter().enumerate() { + let mut uni = TwoSidedDistanceCache::default(); + let mut reauth = TwoSidedDistanceCache::default(); + for (key, (left_vals, right_vals)) in new.map.into_iter() { + // Use one of the stored raw ids to classify the op + let sample_id = left_vals + .first() + .map(|v| v.idx) + .or_else(|| right_vals.first().map(|v| v.idx)) + .unwrap_or(0); + let mq = self.distance_comparator.query_length as u64; + let md = self.distance_comparator.max_db_size as u64; + let q_idx = (sample_id % mq) as usize; + let q_nr = q_idx / ROTATIONS; + let b_id = sample_id / (md * mq); + let classify = match self.batch_ops_map.get(&b_id) { + Some(v) => v.get(q_nr).copied().unwrap_or(RequestOp::Other), + None => RequestOp::Other, + }; + match classify { + RequestOp::Uniqueness => { + uni.map.insert(key, 
(left_vals, right_vals)); + } + RequestOp::Reauth => { + reauth.map.insert(key, (left_vals, right_vals)); + } + RequestOp::Other => {} + } + } + self.both_side_match_distances_buffer_uni[dev_idx].extend(uni); + self.both_side_match_distances_buffer_reauth[dev_idx].extend(reauth); } - // check if we have enough results to calculate 2D bucket statistics - let match_distance_2d_count = self - .both_side_match_distances_buffer + // check if we have enough results to calculate 2D bucket statistics (uniqueness) + let match_distance_2d_count_uni = self + .both_side_match_distances_buffer_uni .iter() .map(|x| x.len()) .sum::(); tracing::info!( - "Match distance 2D count: {match_distance_2d_count}/{}", + "Match distance 2D count (uniqueness): {match_distance_2d_count_uni}/{}", self.match_distances_2d_buffer_size ); - if match_distance_2d_count >= self.match_distances_2d_buffer_size { - tracing::info!("Calculating bucket statistics for both sides"); - let mut both_side_match_distances_buffer = + if match_distance_2d_count_uni >= self.match_distances_2d_buffer_size { + tracing::info!("Calculating 2D bucket statistics (uniqueness)"); + let mut buffer = vec![TwoSidedDistanceCache::default(); self.device_manager.device_count()]; - std::mem::swap( - &mut self.both_side_match_distances_buffer, - &mut both_side_match_distances_buffer, - ); + std::mem::swap(&mut self.both_side_match_distances_buffer_uni, &mut buffer); let min_distance_cache = TwoSidedDistanceCache::into_min_distance_cache( - both_side_match_distances_buffer, + buffer, &mut self.phase2_2d_buckets, &self.streams[0], ); @@ -1667,23 +1787,6 @@ impl ServerActor { &self.streams[0], &thresholds, ); - tracing::info!("Bucket statistics calculated"); - let mut buckets_2d_string = String::new(); - for (i, bucket) in buckets_2d.iter().enumerate() { - let left_idx = i / self.n_buckets; - let right_idx = i % self.n_buckets; - let step = 0.375 / self.n_buckets as f64; - buckets_2d_string += &format!( - "Bucket 
({:.3}-{:.3},{:.3}-{:.3}): {}\n", - left_idx as f64 * step, - (left_idx + 1) as f64 * step, - right_idx as f64 * step, - (right_idx + 1) as f64 * step, - bucket - ); - } - tracing::info!("Bucket statistics calculated:\n{}", buckets_2d_string); - // Fill the 2D anonymized statistics structure for propagation self.anonymized_bucket_statistics_2d.fill_buckets( &buckets_2d, @@ -1693,6 +1796,52 @@ impl ServerActor { ); } + // check if we have enough results to calculate 2D bucket statistics (reauth) + let match_distance_2d_count_reauth = self + .both_side_match_distances_buffer_reauth + .iter() + .map(|x| x.len()) + .sum::(); + tracing::info!( + "Match distance 2D count (reauth): {match_distance_2d_count_reauth}/{}", + self.match_distances_2d_buffer_size + ); + if match_distance_2d_count_reauth >= self.match_distances_2d_buffer_size { + tracing::info!("Calculating 2D bucket statistics (reauth)"); + let mut buffer = + vec![TwoSidedDistanceCache::default(); self.device_manager.device_count()]; + std::mem::swap( + &mut self.both_side_match_distances_buffer_reauth, + &mut buffer, + ); + + let min_distance_cache = TwoSidedDistanceCache::into_min_distance_cache( + buffer, + &mut self.phase2_2d_buckets, + &self.streams[0], + ); + let thresholds = (1..=self.n_buckets) + .map(|x: usize| { + Circuits::translate_threshold_a( + MATCH_THRESHOLD_RATIO / (self.n_buckets as f64) * (x as f64), + ) as u16 + }) + .collect::>(); + + let buckets_2d = min_distance_cache.compute_buckets( + &mut self.phase2_2d_buckets, + &self.streams[0], + &thresholds, + ); + // Fill the 2D anonymized statistics structure for propagation (reauth) + self.anonymized_bucket_statistics_2d_reauth.fill_buckets( + &buckets_2d, + MATCH_THRESHOLD_RATIO, + self.anonymized_bucket_statistics_left + .next_start_time_utc_timestamp, + ); + } + // Instead of sending to return_channel, we'll return this at the end let result = ServerJobResult { merged_results, @@ -1720,6 +1869,15 @@ impl ServerActor { 
anonymized_bucket_statistics_left: self.anonymized_bucket_statistics_left.clone(), anonymized_bucket_statistics_right: self.anonymized_bucket_statistics_right.clone(), anonymized_bucket_statistics_2d: self.anonymized_bucket_statistics_2d.clone(), + anonymized_bucket_statistics_left_reauth: self + .anonymized_bucket_statistics_left_reauth + .clone(), + anonymized_bucket_statistics_right_reauth: self + .anonymized_bucket_statistics_right_reauth + .clone(), + anonymized_bucket_statistics_2d_reauth: self + .anonymized_bucket_statistics_2d_reauth + .clone(), anonymized_bucket_statistics_left_mirror: self .anonymized_bucket_statistics_left_mirror .clone(), @@ -1853,7 +2011,7 @@ impl ServerActor { self.device_manager.await_streams(streams); - let indices = match_distances_indices + let indices_vecs = match_distances_indices .iter() .enumerate() .map(|(i, x)| { @@ -1861,10 +2019,10 @@ impl ServerActor { }) .collect::>(); - let resort_indices = (0..indices.len()) + let resort_indices_all = (0..indices_vecs.len()) .map(|i| { - let mut resort_indices = (0..indices[i].len()).collect::>(); - resort_indices.sort_by_key(|&j| indices[i][j]); + let mut resort_indices = (0..indices_vecs[i].len()).collect::>(); + resort_indices.sort_by_key(|&j| indices_vecs[i][j]); resort_indices }) .collect::>(); @@ -1881,51 +2039,96 @@ impl ServerActor { } result } - // sort all indices, and create bitmaps from them - let indices_bitmaps = indices - .into_iter() - .map(|mut x| { - x.sort(); - x.truncate(self.match_distances_buffer_size); - x - }) - .map(|mut sorted| { - for id in &mut sorted { - // re-map the ids to remove the ROTATION aspect from them + // Helper: build per-op resort indices and bitmasks + let mq = self.distance_comparator.query_length as u64; // max_batch_size * ALL_ROTATIONS + let md = self.distance_comparator.max_db_size as u64; + let chunk_size_words = self.match_distances_buffer_size.div_ceil(64); + + let build_subset = |want_op: RequestOp| { + let mut subset_resort: Vec> = + 
Vec::with_capacity(resort_indices_all.len()); + let mut subset_bitmasks: Vec> = + Vec::with_capacity(resort_indices_all.len()); + let mut subset_lengths: Vec = Vec::with_capacity(resort_indices_all.len()); + + for (dev_i, order) in resort_indices_all.iter().enumerate() { + let idx_vec = &indices_vecs[dev_i]; + let mut filtered_positions = Vec::with_capacity(order.len()); + let mut filtered_ids = Vec::with_capacity(order.len()); + for &pos in order.iter() { + let id_raw = idx_vec[pos]; + let q_idx = (id_raw % mq) as usize; + let q_nr = q_idx / ROTATIONS; + let b_id = id_raw / (md * mq); + let classify = match self.batch_ops_map.get(&b_id) { + Some(v) => v.get(q_nr).copied().unwrap_or(RequestOp::Other), + None => RequestOp::Other, + }; + if classify == want_op { + filtered_positions.push(pos); + filtered_ids.push(id_raw); + } + } + + // truncate to buffer size + let truncate_len = filtered_positions + .len() + .min(self.match_distances_buffer_size); + filtered_positions.truncate(truncate_len); + filtered_ids.truncate(truncate_len); + + // remove rotations for grouping + for id in &mut filtered_ids { *id /= ROTATIONS as u64; } - sorted - }) - .map(|sorted| ids_to_bitvec(&sorted)) - .collect_vec(); + let bitvec = if filtered_ids.is_empty() { + vec![0u64; chunk_size_words] + } else { + let mut bv = ids_to_bitvec(&filtered_ids); + if bv.len() < chunk_size_words { + bv.resize(chunk_size_words, 0); + } + bv + }; + + subset_lengths.push(truncate_len); + subset_resort.push(filtered_positions); + subset_bitmasks.push(bitvec); + } + + (subset_resort, subset_bitmasks, subset_lengths) + }; + + // Always compute Uniqueness buckets + let (resort_uni, bitmasks_uni, _) = build_subset(RequestOp::Uniqueness); let shares = sort_shares_by_indices( &self.device_manager, - &resort_indices, + &resort_uni, match_distances_buffers_codes, + // Use max per-device length, function slices per device internally self.match_distances_buffer_size, streams, ); - let match_distances_buffers_codes_view 
= shares.iter().map(|x| x.as_view()).collect::>(); - let shares = sort_shares_by_indices( &self.device_manager, - &resort_indices, + &resort_uni, match_distances_buffers_masks, self.match_distances_buffer_size, streams, ); - let match_distances_buffers_masks_view = shares.iter().map(|x| x.as_view()).collect::>(); + // Reset buckets before computing a new set + reset_single_share(self.device_manager.devices(), &self.buckets, 0, streams, 0); self.phase2_buckets .compare_multiple_thresholds_while_aggregating_per_query( &match_distances_buffers_codes_view, &match_distances_buffers_masks_view, - &indices_bitmaps, + &bitmasks_uni, streams, &(1..=self.n_buckets) .map(|x: usize| { @@ -1937,38 +2140,36 @@ impl ServerActor { &mut self.buckets, ); - let buckets = self.phase2_buckets.open_buckets(&self.buckets, streams); - - tracing::info!("Buckets: {:?}", buckets); + let buckets_uni = self.phase2_buckets.open_buckets(&self.buckets, streams); match (eye_db, orientation) { (Eye::Left, Orientation::Normal) => { self.anonymized_bucket_statistics_left.fill_buckets( - &buckets, + &buckets_uni, MATCH_THRESHOLD_RATIO, self.anonymized_bucket_statistics_left .next_start_time_utc_timestamp, ); tracing::info!( - "Normal bucket results (left):\n{}", + "Normal bucket results (left, uniqueness):\n{}", self.anonymized_bucket_statistics_left ); } (Eye::Right, Orientation::Normal) => { self.anonymized_bucket_statistics_right.fill_buckets( - &buckets, + &buckets_uni, MATCH_THRESHOLD_RATIO, self.anonymized_bucket_statistics_right .next_start_time_utc_timestamp, ); tracing::info!( - "Normal bucket results (right):\n{}", + "Normal bucket results (right, uniqueness):\n{}", self.anonymized_bucket_statistics_right ); } (Eye::Left, Orientation::Mirror) => { self.anonymized_bucket_statistics_left_mirror.fill_buckets( - &buckets, + &buckets_uni, MATCH_THRESHOLD_RATIO, self.anonymized_bucket_statistics_left_mirror .next_start_time_utc_timestamp, @@ -1980,7 +2181,7 @@ impl ServerActor { } (Eye::Right, 
Orientation::Mirror) => { self.anonymized_bucket_statistics_right_mirror.fill_buckets( - &buckets, + &buckets_uni, MATCH_THRESHOLD_RATIO, self.anonymized_bucket_statistics_right_mirror .next_start_time_utc_timestamp, @@ -1992,6 +2193,77 @@ impl ServerActor { } } + // Compute Reauth buckets only for Normal orientation and above threshold + if orientation == Orientation::Normal { + let (resort_reauth, bitmasks_reauth, lengths_reauth) = + build_subset(RequestOp::Reauth); + let total_reauth_count: usize = lengths_reauth.iter().sum(); + tracing::info!( + "Reauth distances collected across devices: {} (min required: {})", + total_reauth_count, + self.reauth_match_distances_min_count + ); + if total_reauth_count >= self.reauth_match_distances_min_count { + let shares = sort_shares_by_indices( + &self.device_manager, + &resort_reauth, + match_distances_buffers_codes, + self.match_distances_buffer_size, + streams, + ); + let match_distances_buffers_codes_view = + shares.iter().map(|x| x.as_view()).collect::>(); + let shares = sort_shares_by_indices( + &self.device_manager, + &resort_reauth, + match_distances_buffers_masks, + self.match_distances_buffer_size, + streams, + ); + let match_distances_buffers_masks_view = + shares.iter().map(|x| x.as_view()).collect::>(); + + reset_single_share(self.device_manager.devices(), &self.buckets, 0, streams, 0); + self.phase2_buckets + .compare_multiple_thresholds_while_aggregating_per_query( + &match_distances_buffers_codes_view, + &match_distances_buffers_masks_view, + &bitmasks_reauth, + streams, + &(1..=self.n_buckets) + .map(|x: usize| { + Circuits::translate_threshold_a( + MATCH_THRESHOLD_RATIO / (self.n_buckets as f64) + * (x as f64), + ) as u16 + }) + .collect::>(), + &mut self.buckets, + ); + let buckets_reauth = self.phase2_buckets.open_buckets(&self.buckets, streams); + match eye_db { + Eye::Left => { + self.anonymized_bucket_statistics_left_reauth.fill_buckets( + &buckets_reauth, + MATCH_THRESHOLD_RATIO, + 
self.anonymized_bucket_statistics_left_reauth + .next_start_time_utc_timestamp, + ); + } + Eye::Right => { + self.anonymized_bucket_statistics_right_reauth.fill_buckets( + &buckets_reauth, + MATCH_THRESHOLD_RATIO, + self.anonymized_bucket_statistics_right_reauth + .next_start_time_utc_timestamp, + ); + } + } + } else { + tracing::info!("Reauth distances below threshold, skipping 1D reauth stats"); + } + } + let reset_all_buffers = |counter: &[CudaSlice], indices: &[CudaSlice], @@ -2003,7 +2275,7 @@ impl ServerActor { reset_share(self.device_manager.devices(), codes, 0xff, streams); }; - // Reset all buffers used in this calculation + // Reset all device-side buffers used in this calculation reset_all_buffers( match_distances_counters, match_distances_indices, @@ -3377,15 +3649,22 @@ fn sort_shares_by_indices( .iter() .map(|&j| a[i][j]) .collect::>(); - let a = htod_on_stream_sync(&new_a[..length], &device_manager.device(i), &streams[i]) - .unwrap(); + let slice_len = new_a.len().min(length); + let a = + htod_on_stream_sync(&new_a[..slice_len], &device_manager.device(i), &streams[i]) + .unwrap(); let new_b = resort_indices[i] .iter() .map(|&j| b[i][j]) .collect::>(); - let b = htod_on_stream_sync(&new_b[..length], &device_manager.device(i), &streams[i]) - .unwrap(); + let slice_len_b = new_b.len().min(length); + let b = htod_on_stream_sync( + &new_b[..slice_len_b], + &device_manager.device(i), + &streams[i], + ) + .unwrap(); ChunkShare::new(a, b) }) diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs b/iris-mpc-gpu/tests/e2e-anon-stats.rs index 739eade3a..f627b764a 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -19,6 +19,7 @@ mod e2e_anon_stats_test { const MAX_BATCH_SIZE: usize = 64; const N_BUCKETS: usize = 8; const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 6; + const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 100; const MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT: usize = 5000; const MATCH_DISTANCES_2D_BUFFER_SIZE: 
usize = 1 << 6; @@ -114,6 +115,7 @@ mod e2e_anon_stats_test { MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT, MATCH_DISTANCES_2D_BUFFER_SIZE, N_BUCKETS, + REAUTH_MATCH_DISTANCES_MIN_COUNT, true, false, false, @@ -149,6 +151,7 @@ mod e2e_anon_stats_test { MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT, MATCH_DISTANCES_2D_BUFFER_SIZE, N_BUCKETS, + REAUTH_MATCH_DISTANCES_MIN_COUNT, true, false, false, @@ -184,6 +187,7 @@ mod e2e_anon_stats_test { MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT, MATCH_DISTANCES_2D_BUFFER_SIZE, N_BUCKETS, + REAUTH_MATCH_DISTANCES_MIN_COUNT, true, false, false, diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs index 907a041d2..016e45795 100644 --- a/iris-mpc-gpu/tests/e2e.rs +++ b/iris-mpc-gpu/tests/e2e.rs @@ -21,6 +21,7 @@ mod e2e_test { const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 7; const MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT: usize = 100; const MATCH_DISTANCES_2D_BUFFER_SIZE: usize = 1 << 6; + const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 100; const MAX_DELETIONS_PER_BATCH: usize = 10; const MAX_RESET_UPDATES_PER_BATCH: usize = 10; @@ -112,6 +113,7 @@ mod e2e_test { MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT, MATCH_DISTANCES_2D_BUFFER_SIZE, N_BUCKETS, + REAUTH_MATCH_DISTANCES_MIN_COUNT, true, false, false, @@ -147,6 +149,7 @@ mod e2e_test { MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT, MATCH_DISTANCES_2D_BUFFER_SIZE, N_BUCKETS, + REAUTH_MATCH_DISTANCES_MIN_COUNT, true, false, false, @@ -182,6 +185,7 @@ mod e2e_test { MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT, MATCH_DISTANCES_2D_BUFFER_SIZE, N_BUCKETS, + REAUTH_MATCH_DISTANCES_MIN_COUNT, true, false, false, diff --git a/iris-mpc/bin/server.rs b/iris-mpc/bin/server.rs index 7d5c34cf6..ab25bfa73 100644 --- a/iris-mpc/bin/server.rs +++ b/iris-mpc/bin/server.rs @@ -1309,6 +1309,7 @@ async fn server_main(config: Config) -> Result<()> { config.match_distances_buffer_size_extra_percent, config.match_distances_2d_buffer_size, config.n_buckets, + config.reauth_match_distances_min_count, 
config.return_partial_results, config.disable_persistence, config.enable_debug_timing, @@ -1398,6 +1399,8 @@ async fn server_main(config: Config) -> Result<()> { matched_batch_request_ids, anonymized_bucket_statistics_left, anonymized_bucket_statistics_right, + anonymized_bucket_statistics_left_reauth, + anonymized_bucket_statistics_right_reauth, anonymized_bucket_statistics_left_mirror, anonymized_bucket_statistics_right_mirror, successful_reauths, @@ -1410,6 +1413,7 @@ async fn server_main(config: Config) -> Result<()> { actor_data: _, full_face_mirror_attack_detected, anonymized_bucket_statistics_2d, + anonymized_bucket_statistics_2d_reauth, }) = rx.recv().await { let dummy_deletion_shares = get_dummy_shares_for_deletion(party_id); @@ -1849,6 +1853,35 @@ async fn server_main(config: Config) -> Result<()> { .await?; } + // Send reauth anonymized statistics (normal orientation only) + if (config_bg.enable_sending_anonymized_stats_message) + && (!anonymized_bucket_statistics_left_reauth.buckets.is_empty() + || !anonymized_bucket_statistics_right_reauth.buckets.is_empty()) + { + tracing::info!("Sending anonymized stats results (reauth)"); + let anonymized_statistics_results = [ + anonymized_bucket_statistics_left_reauth, + anonymized_bucket_statistics_right_reauth, + ]; + let anonymized_statistics_results = anonymized_statistics_results + .iter() + .map(|anonymized_bucket_statistics| { + serde_json::to_string(anonymized_bucket_statistics) + .wrap_err("failed to serialize anonymized statistics result (reauth)") + }) + .collect::>>()?; + + send_results_to_sns( + anonymized_statistics_results, + &metadata, + &sns_client_bg, + &config_bg, + &anonymized_statistics_attributes, + ANONYMIZED_STATISTICS_MESSAGE_TYPE, + ) + .await?; + } + // Send mirror orientation statistics separately with their own flag if (config_bg.enable_sending_mirror_anonymized_stats_message) && (!anonymized_bucket_statistics_left_mirror.buckets.is_empty() @@ -1918,6 +1951,42 @@ async fn 
server_main(config: Config) -> Result<()> { .await?; } + // Send 2D anonymized statistics for reauth if present + if config_bg.enable_sending_anonymized_stats_2d_message + && !anonymized_bucket_statistics_2d_reauth.buckets.is_empty() + { + tracing::info!("Sending 2D anonymized stats results (reauth)"); + let serialized = serde_json::to_string(&anonymized_bucket_statistics_2d_reauth) + .wrap_err("failed to serialize 2D anonymized statistics result (reauth)")?; + + let now_ms = Utc::now().timestamp_millis(); + let sha = iris_mpc_common::helpers::sha256::sha256_bytes(&serialized); + let content_hash = hex::encode(sha); + let s3_key = format!("stats2d/{}_{}_reauth.json", now_ms, content_hash); + + upload_file_to_s3( + &config_bg.sns_buffer_bucket_name, + &s3_key, + s3_client_bg.clone(), + serialized.as_bytes(), + ) + .await + .wrap_err("failed to upload 2D anonymized statistics (reauth) to s3")?; + + let payload = serde_json::to_string(&serde_json::json!({ + "s3_key": s3_key, + }))?; + send_results_to_sns( + vec![payload], + &metadata, + &sns_client_bg, + &config_bg, + &anonymized_statistics_2d_attributes, + ANONYMIZED_STATISTICS_2D_MESSAGE_TYPE, + ) + .await?; + } + shutdown_handler_bg.decrement_batches_pending_completion(); } diff --git a/iris-mpc/src/services/processors/job.rs b/iris-mpc/src/services/processors/job.rs index e5237bad0..2dcc8bf3b 100644 --- a/iris-mpc/src/services/processors/job.rs +++ b/iris-mpc/src/services/processors/job.rs @@ -55,6 +55,8 @@ pub async fn process_job_result( matched_batch_request_ids, anonymized_bucket_statistics_left, anonymized_bucket_statistics_right, + anonymized_bucket_statistics_left_reauth, + anonymized_bucket_statistics_right_reauth, successful_reauths, reauth_target_indices, reauth_or_rule_used, @@ -446,6 +448,35 @@ pub async fn process_job_result( .await?; } + // Send reauth anonymized statistics if present + if (config.enable_sending_anonymized_stats_message) + && 
(!anonymized_bucket_statistics_left_reauth.buckets.is_empty() + || !anonymized_bucket_statistics_right_reauth.buckets.is_empty()) + { + tracing::info!("Sending anonymized stats results (reauth)"); + let anonymized_statistics_results = [ + anonymized_bucket_statistics_left_reauth, + anonymized_bucket_statistics_right_reauth, + ]; + let anonymized_statistics_results = anonymized_statistics_results + .iter() + .map(|anonymized_bucket_statistics| { + serde_json::to_string(anonymized_bucket_statistics) + .wrap_err("failed to serialize anonymized statistics result (reauth)") + }) + .collect::>>()?; + + send_results_to_sns( + anonymized_statistics_results, + &metadata, + sns_client, + config, + anonymized_statistics_attributes, + ANONYMIZED_STATISTICS_MESSAGE_TYPE, + ) + .await?; + } + tracing::info!("Sending {} reset check results", reset_check_results.len()); send_results_to_sns( reset_check_results, From 6b15ee62883b0a85a4e63fc604db50e6c16a9cfe Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Fri, 19 Sep 2025 16:33:37 +0200 Subject: [PATCH 02/16] fmt --- iris-mpc-common/src/helpers/statistics.rs | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/iris-mpc-common/src/helpers/statistics.rs b/iris-mpc-common/src/helpers/statistics.rs index 0c258a779..9746e0ae3 100644 --- a/iris-mpc-common/src/helpers/statistics.rs +++ b/iris-mpc-common/src/helpers/statistics.rs @@ -111,12 +111,7 @@ impl BucketStatistics { eye: Eye, operation: Operation, ) -> Self { - let mut bs = Self::new( - match_distances_buffer_size, - n_buckets, - party_id, - eye, - ); + let mut bs = Self::new(match_distances_buffer_size, n_buckets, party_id, eye); bs.operation = operation; bs } @@ -277,11 +272,7 @@ impl BucketStatistics2D { party_id: usize, operation: Operation, ) -> Self { - let mut bs = Self::new( - match_distances_buffer_size, - n_buckets_per_side, - party_id, - ); + let mut bs = Self::new(match_distances_buffer_size, n_buckets_per_side, party_id); bs.operation = 
operation; bs } From 6c2f9ce92e97b7117fee5119b2a321ac817f7dea Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Sat, 20 Sep 2025 14:58:06 +0200 Subject: [PATCH 03/16] test fix --- iris-mpc-common/tests/statistics.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/iris-mpc-common/tests/statistics.rs b/iris-mpc-common/tests/statistics.rs index a511a9a9e..f31dc728a 100644 --- a/iris-mpc-common/tests/statistics.rs +++ b/iris-mpc-common/tests/statistics.rs @@ -76,6 +76,7 @@ mod tests { assert_eq!(value["n_buckets"], json!(2)); assert_eq!(value["match_distances_buffer_size"], json!(128)); assert_eq!(value["is_mirror_orientation"], json!(false)); + assert_eq!(value["operation"], json!("Uniqueness")); } #[test] @@ -96,7 +97,8 @@ mod tests { "eye": "Left", "start_time_utc_timestamp": 1700000000, "end_time_utc_timestamp": null, - "is_mirror_orientation": false + "is_mirror_orientation": false, + "operation": "Uniqueness" }) .to_string(); From c065ae381a43daff7d3b226c133b0484764f816e Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Sun, 21 Sep 2025 22:14:28 +0200 Subject: [PATCH 04/16] gpu test fix --- iris-mpc-gpu/src/server/actor.rs | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 61cb5ad98..c8b6aaf68 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -1721,7 +1721,7 @@ impl ServerActor { // Partition by operation and extend the appropriate caches for (dev_idx, new) in two_sided_match_distances.into_iter().enumerate() { - let mut uni = TwoSidedDistanceCache::default(); + let mut uniqueness = TwoSidedDistanceCache::default(); let mut reauth = TwoSidedDistanceCache::default(); for (key, (left_vals, right_vals)) in new.map.into_iter() { // Use one of the stored raw ids to classify the op @@ -1741,7 +1741,7 @@ impl ServerActor { }; match classify { RequestOp::Uniqueness => { - 
uni.map.insert(key, (left_vals, right_vals)); + uniqueness.map.insert(key, (left_vals, right_vals)); } RequestOp::Reauth => { reauth.map.insert(key, (left_vals, right_vals)); @@ -1749,7 +1749,7 @@ impl ServerActor { RequestOp::Other => {} } } - self.both_side_match_distances_buffer_uni[dev_idx].extend(uni); + self.both_side_match_distances_buffer_uni[dev_idx].extend(uniqueness); self.both_side_match_distances_buffer_reauth[dev_idx].extend(reauth); } @@ -3649,22 +3649,20 @@ fn sort_shares_by_indices( .iter() .map(|&j| a[i][j]) .collect::>(); - let slice_len = new_a.len().min(length); - let a = - htod_on_stream_sync(&new_a[..slice_len], &device_manager.device(i), &streams[i]) - .unwrap(); + // Pad to exactly `length` entries (zeros) to satisfy kernel chunking (multiple of 64) + let mut pad_a = vec![0u16; length]; + let copy_len = new_a.len().min(length); + pad_a[..copy_len].copy_from_slice(&new_a[..copy_len]); + let a = htod_on_stream_sync(&pad_a, &device_manager.device(i), &streams[i]).unwrap(); let new_b = resort_indices[i] .iter() .map(|&j| b[i][j]) .collect::>(); - let slice_len_b = new_b.len().min(length); - let b = htod_on_stream_sync( - &new_b[..slice_len_b], - &device_manager.device(i), - &streams[i], - ) - .unwrap(); + let mut pad_b = vec![0u16; length]; + let copy_len_b = new_b.len().min(length); + pad_b[..copy_len_b].copy_from_slice(&new_b[..copy_len_b]); + let b = htod_on_stream_sync(&pad_b, &device_manager.device(i), &streams[i]).unwrap(); ChunkShare::new(a, b) }) From 39643cdef76c832eb067e983f5ccc1bbb5372995 Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 06:59:30 +0200 Subject: [PATCH 05/16] gpu tests extensions for reauth anon stats buckets --- iris-mpc-common/src/test.rs | 30 ++++++++++++++++++++++++++++ iris-mpc-gpu/tests/e2e-anon-stats.rs | 12 ++++++----- iris-mpc-gpu/tests/e2e.rs | 3 ++- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 
4336e615b..4e7288a89 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1256,6 +1256,8 @@ impl TestCaseGenerator { matched_batch_request_ids, anonymized_bucket_statistics_left, anonymized_bucket_statistics_right, + anonymized_bucket_statistics_left_reauth, + anonymized_bucket_statistics_right_reauth, anonymized_bucket_statistics_left_mirror, anonymized_bucket_statistics_right_mirror, successful_reauths, @@ -1265,6 +1267,19 @@ impl TestCaseGenerator { .. } = res; + // Operation tagging checks for 1D stats + use crate::helpers::statistics::Operation; + assert_eq!(anonymized_bucket_statistics_left.operation, Operation::Uniqueness); + assert_eq!(anonymized_bucket_statistics_right.operation, Operation::Uniqueness); + assert_eq!(anonymized_bucket_statistics_left_mirror.operation, Operation::Uniqueness); + assert_eq!(anonymized_bucket_statistics_right_mirror.operation, Operation::Uniqueness); + if !anonymized_bucket_statistics_left_reauth.buckets.is_empty() + || !anonymized_bucket_statistics_right_reauth.buckets.is_empty() + { + assert_eq!(anonymized_bucket_statistics_left_reauth.operation, Operation::Reauth); + assert_eq!(anonymized_bucket_statistics_right_reauth.operation, Operation::Reauth); + } + if let Some(bucket_statistic_parameters) = &self.bucket_statistic_parameters { // Check that normal orientation statistics have is_mirror_orientation set to false assert!(!anonymized_bucket_statistics_left.is_mirror_orientation, @@ -1809,6 +1824,8 @@ impl SimpleAnonStatsTestGenerator { matched_batch_request_ids, anonymized_bucket_statistics_left, anonymized_bucket_statistics_right, + anonymized_bucket_statistics_left_reauth, + anonymized_bucket_statistics_right_reauth, anonymized_bucket_statistics_left_mirror, anonymized_bucket_statistics_right_mirror, .. 
@@ -1846,6 +1863,19 @@ impl SimpleAnonStatsTestGenerator { self.bucket_statistic_parameters.num_buckets, )?; + // Operation tagging checks for 1D stats + use crate::helpers::statistics::Operation; + assert_eq!(anonymized_bucket_statistics_left.operation, Operation::Uniqueness); + assert_eq!(anonymized_bucket_statistics_right.operation, Operation::Uniqueness); + assert_eq!(anonymized_bucket_statistics_left_mirror.operation, Operation::Uniqueness); + assert_eq!(anonymized_bucket_statistics_right_mirror.operation, Operation::Uniqueness); + if !anonymized_bucket_statistics_left_reauth.buckets.is_empty() + || !anonymized_bucket_statistics_right_reauth.buckets.is_empty() + { + assert_eq!(anonymized_bucket_statistics_left_reauth.operation, Operation::Reauth); + assert_eq!(anonymized_bucket_statistics_right_reauth.operation, Operation::Reauth); + } + if !anonymized_bucket_statistics_left.is_empty() { tracing::info!("Got anonymized bucket statistics for left side, checking..."); tracing::info!("Plain distances left : {:?}", self.plain_distances_left); diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs b/iris-mpc-gpu/tests/e2e-anon-stats.rs index f627b764a..12e66a084 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -1,9 +1,10 @@ -#[cfg(feature = "gpu_dependent")] +// #[cfg(feature = "gpu_dependent")] mod e2e_anon_stats_test { use cudarc::nccl::Id; use eyre::Result; use iris_mpc_common::{ helpers::inmemory_store::InMemoryStore, + helpers::statistics::Operation, job::Eye, test::{generate_full_test_db, load_test_db, SimpleAnonStatsTestGenerator}, }; @@ -19,7 +20,8 @@ mod e2e_anon_stats_test { const MAX_BATCH_SIZE: usize = 64; const N_BUCKETS: usize = 8; const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 6; - const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 100; + // set to a small number for fast test; here no reauth is generated anyway + const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 1; const MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT: 
usize = 5000; const MATCH_DISTANCES_2D_BUFFER_SIZE: usize = 1 << 6; @@ -223,9 +225,9 @@ mod e2e_anon_stats_test { drop(handle1); drop(handle2); - actor0_task.await.unwrap(); - actor1_task.await.unwrap(); - actor2_task.await.unwrap(); + actor0_task.await?; + actor1_task.await?; + actor2_task.await?; Ok(()) } diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs index 016e45795..24a1f392b 100644 --- a/iris-mpc-gpu/tests/e2e.rs +++ b/iris-mpc-gpu/tests/e2e.rs @@ -21,7 +21,8 @@ mod e2e_test { const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 7; const MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT: usize = 100; const MATCH_DISTANCES_2D_BUFFER_SIZE: usize = 1 << 6; - const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 100; + // Keep low to ensure reauth stats get produced during the test + const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 1; const MAX_DELETIONS_PER_BATCH: usize = 10; const MAX_RESET_UPDATES_PER_BATCH: usize = 10; From 8e77a21a7e804a014d64437392198fc5e0e5224f Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 07:14:08 +0200 Subject: [PATCH 06/16] generate reauth batches --- iris-mpc-common/src/test.rs | 127 ++++++++++++++++++++------- iris-mpc-gpu/tests/e2e-anon-stats.rs | 3 +- 2 files changed, 94 insertions(+), 36 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 4e7288a89..7a2a0d34c 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1269,15 +1269,33 @@ impl TestCaseGenerator { // Operation tagging checks for 1D stats use crate::helpers::statistics::Operation; - assert_eq!(anonymized_bucket_statistics_left.operation, Operation::Uniqueness); - assert_eq!(anonymized_bucket_statistics_right.operation, Operation::Uniqueness); - assert_eq!(anonymized_bucket_statistics_left_mirror.operation, Operation::Uniqueness); - assert_eq!(anonymized_bucket_statistics_right_mirror.operation, Operation::Uniqueness); + assert_eq!( + anonymized_bucket_statistics_left.operation, 
Operation::Uniqueness + ); + assert_eq!( + anonymized_bucket_statistics_right.operation, + Operation::Uniqueness + ); + assert_eq!( + anonymized_bucket_statistics_left_mirror.operation, + Operation::Uniqueness + ); + assert_eq!( + anonymized_bucket_statistics_right_mirror.operation, + Operation::Uniqueness + ); if !anonymized_bucket_statistics_left_reauth.buckets.is_empty() || !anonymized_bucket_statistics_right_reauth.buckets.is_empty() { - assert_eq!(anonymized_bucket_statistics_left_reauth.operation, Operation::Reauth); - assert_eq!(anonymized_bucket_statistics_right_reauth.operation, Operation::Reauth); + assert_eq!( + anonymized_bucket_statistics_left_reauth.operation, + Operation::Reauth + ); + assert_eq!( + anonymized_bucket_statistics_right_reauth.operation, + Operation::Reauth + ); } if let Some(bucket_statistic_parameters) = &self.bucket_statistic_parameters { @@ -1648,28 +1666,50 @@ impl SimpleAnonStatsTestGenerator { } } - fn generate_query(&mut self) -> Option<(String, E2ETemplate, String)> { + fn generate_query(&mut self) -> Option<(String, E2ETemplate, String, Option)> { + use crate::helpers::smpc_request::REAUTH_MESSAGE_TYPE; let request_id = Uuid::new_v4(); let db_index = self.rng.gen_range(0..self.db_state.len()); - let approx_diff_factor = self.rng.gen_range(0.0..0.35); - let mut template = E2ETemplate { - left: self.db_state.plain_dbs[0].db[db_index] - .get_similar_iris(&mut self.rng, approx_diff_factor), - right: self.db_state.plain_dbs[1].db[db_index] - .get_similar_iris(&mut self.rng, approx_diff_factor), + + // Occasionally generate a REAUTH query by targeting the exact DB entry. + // Otherwise generate a UNIQ query similar to a DB entry. 
+ let is_reauth = self.rng.gen_bool(0.2); + let mut template = if is_reauth { + E2ETemplate { + left: self.db_state.plain_dbs[0].db[db_index].clone(), + right: self.db_state.plain_dbs[1].db[db_index].clone(), + } + } else { + let approx_diff_factor = self.rng.gen_range(0.0..0.35); + E2ETemplate { + left: self.db_state.plain_dbs[0].db[db_index] + .get_similar_iris(&mut self.rng, approx_diff_factor), + right: self.db_state.plain_dbs[1].db[db_index] + .get_similar_iris(&mut self.rng, approx_diff_factor), + } }; - let rotation = self.rng.gen_range(0..ROTATIONS); // Rotate the query iris codes + let rotation = self.rng.gen_range(0..ROTATIONS); template.left = template.left.all_rotations()[rotation].clone(); let rotation = self.rng.gen_range(0..ROTATIONS); template.right = template.right.all_rotations()[rotation].clone(); - Some(( - request_id.to_string(), - template, - UNIQUENESS_MESSAGE_TYPE.to_string(), - )) + if is_reauth { + Some(( + request_id.to_string(), + template, + REAUTH_MESSAGE_TYPE.to_string(), + Some(db_index as u32), + )) + } else { + Some(( + request_id.to_string(), + template, + UNIQUENESS_MESSAGE_TYPE.to_string(), + None, + )) + } } #[allow(clippy::type_complexity)] @@ -1685,12 +1725,13 @@ impl SimpleAnonStatsTestGenerator { batch1.full_face_mirror_attacks_detection_enabled = true; batch2.full_face_mirror_attacks_detection_enabled = true; - let (request_id, e2e_template, message_type) = match self.generate_query() { - Some((request_id, e2e_template, message_type)) => { - (request_id, e2e_template, message_type) - } - None => return Ok(None), - }; + let (request_id, e2e_template, message_type, reauth_target_idx) = + match self.generate_query() { + Some((request_id, e2e_template, message_type, reauth_target_idx)) => { + (request_id, e2e_template, message_type, reauth_target_idx) + } + None => return Ok(None), + }; requests.insert(request_id.to_string(), e2e_template.clone()); @@ -1703,7 +1744,7 @@ impl SimpleAnonStatsTestGenerator { 0, 
shared_template.clone(), vec![], - None, + reauth_target_idx.as_ref(), false, message_type.clone(), )?; @@ -1715,7 +1756,7 @@ impl SimpleAnonStatsTestGenerator { 1, shared_template.clone(), vec![], - None, + reauth_target_idx.as_ref(), false, message_type.clone(), )?; @@ -1727,7 +1768,7 @@ impl SimpleAnonStatsTestGenerator { 2, shared_template, vec![], - None, + reauth_target_idx.as_ref(), false, message_type, )?; @@ -1865,15 +1906,33 @@ impl SimpleAnonStatsTestGenerator { // Operation tagging checks for 1D stats use crate::helpers::statistics::Operation; - assert_eq!(anonymized_bucket_statistics_left.operation, Operation::Uniqueness); - assert_eq!(anonymized_bucket_statistics_right.operation, Operation::Uniqueness); - assert_eq!(anonymized_bucket_statistics_left_mirror.operation, Operation::Uniqueness); - assert_eq!(anonymized_bucket_statistics_right_mirror.operation, Operation::Uniqueness); + assert_eq!( + anonymized_bucket_statistics_left.operation, + Operation::Uniqueness + ); + assert_eq!( + anonymized_bucket_statistics_right.operation, + Operation::Uniqueness + ); + assert_eq!( + anonymized_bucket_statistics_left_mirror.operation, + Operation::Uniqueness + ); + assert_eq!( + anonymized_bucket_statistics_right_mirror.operation, + Operation::Uniqueness + ); if !anonymized_bucket_statistics_left_reauth.buckets.is_empty() || !anonymized_bucket_statistics_right_reauth.buckets.is_empty() { - assert_eq!(anonymized_bucket_statistics_left_reauth.operation, Operation::Reauth); - assert_eq!(anonymized_bucket_statistics_right_reauth.operation, Operation::Reauth); + assert_eq!( + anonymized_bucket_statistics_left_reauth.operation, + Operation::Reauth + ); + assert_eq!( + anonymized_bucket_statistics_right_reauth.operation, + Operation::Reauth + ); } if !anonymized_bucket_statistics_left.is_empty() { diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs b/iris-mpc-gpu/tests/e2e-anon-stats.rs index 12e66a084..ca86ccdf9 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ 
b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -1,10 +1,9 @@ -// #[cfg(feature = "gpu_dependent")] +#[cfg(feature = "gpu_dependent")] mod e2e_anon_stats_test { use cudarc::nccl::Id; use eyre::Result; use iris_mpc_common::{ helpers::inmemory_store::InMemoryStore, - helpers::statistics::Operation, job::Eye, test::{generate_full_test_db, load_test_db, SimpleAnonStatsTestGenerator}, }; From f514b1fe5a366e8a9b169e0d37f0541b42cf67c3 Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 11:15:39 +0200 Subject: [PATCH 07/16] add test case for reauth anon stats --- iris-mpc-common/src/test.rs | 188 ++++++++++++++++++++------- iris-mpc-gpu/tests/e2e-anon-stats.rs | 8 +- 2 files changed, 145 insertions(+), 51 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 7a2a0d34c..6f87d8f67 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1301,14 +1301,14 @@ impl TestCaseGenerator { if let Some(bucket_statistic_parameters) = &self.bucket_statistic_parameters { // Check that normal orientation statistics have is_mirror_orientation set to false assert!(!anonymized_bucket_statistics_left.is_mirror_orientation, - "Normal orientation left statistics should have is_mirror_orientation = false"); + "Normal orientation left statistics should have is_mirror_orientation = false"); assert!(!anonymized_bucket_statistics_right.is_mirror_orientation, - "Normal orientation right statistics should have is_mirror_orientation = false"); + "Normal orientation right statistics should have is_mirror_orientation = false"); // Check that mirror orientation statistics have is_mirror_orientation set to true assert!(anonymized_bucket_statistics_left_mirror.is_mirror_orientation, - "Mirror orientation left statistics should have is_mirror_orientation = true"); + "Mirror orientation left statistics should have is_mirror_orientation = true"); assert!(anonymized_bucket_statistics_right_mirror.is_mirror_orientation, - "Mirror orientation 
right statistics should have is_mirror_orientation = true"); + "Mirror orientation right statistics should have is_mirror_orientation = true"); // Perform some very basic checks on the bucket statistics, not checking the results here check_bucket_statistics( @@ -1646,7 +1646,9 @@ pub fn load_test_db(party_db: &PartyDb, loader: &mut impl InMemoryStore) { pub struct SimpleAnonStatsTestGenerator { db_state: TestDb, plain_distances_left: Vec, + plain_distances_left_reauth: Vec, plain_distances_right: Vec, + plain_distances_right_reauth: Vec, plain_distances_left_mirror: Vec, plain_distances_right_mirror: Vec, bucket_statistic_parameters: BucketStatisticParameters, @@ -1659,7 +1661,9 @@ impl SimpleAnonStatsTestGenerator { db_state: db, bucket_statistic_parameters: BucketStatisticParameters { num_buckets }, plain_distances_left: vec![], + plain_distances_left_reauth: vec![], plain_distances_right: vec![], + plain_distances_right_reauth: vec![], plain_distances_left_mirror: vec![], plain_distances_right_mirror: vec![], rng: StdRng::seed_from_u64(internal_seed), @@ -1673,7 +1677,7 @@ impl SimpleAnonStatsTestGenerator { // Occasionally generate a REAUTH query by targeting the exact DB entry. // Otherwise generate a UNIQ query similar to a DB entry. 
- let is_reauth = self.rng.gen_bool(0.2); + let is_reauth = self.rng.gen_bool(0.1); let mut template = if is_reauth { E2ETemplate { left: self.db_state.plain_dbs[0].db[db_index].clone(), @@ -1713,11 +1717,11 @@ impl SimpleAnonStatsTestGenerator { } #[allow(clippy::type_complexity)] - fn generate_query_batch( + pub fn generate_query_batch( &mut self, - ) -> Result)>> { + ) -> Result)>> { tracing::info!("Generating query batch for simple anonymized statistics test"); - let mut requests: HashMap = HashMap::new(); + let mut requests: HashMap = HashMap::new(); let mut batch0 = BatchQuery::default(); let mut batch1 = BatchQuery::default(); let mut batch2 = BatchQuery::default(); @@ -1733,7 +1737,10 @@ impl SimpleAnonStatsTestGenerator { None => return Ok(None), }; - requests.insert(request_id.to_string(), e2e_template.clone()); + requests.insert( + request_id.to_string(), + (e2e_template.clone(), message_type.clone()), + ); let shared_template = e2e_template.to_shared_template(true, &mut self.rng); @@ -1782,7 +1789,7 @@ impl SimpleAnonStatsTestGenerator { _idx: u32, _was_match: bool, _matched_batch_request_ids: &[String], - _requests: &HashMap, + _requests: &HashMap, ) -> Result<()> { // In this simple test, we don't have any specific checks to perform // as we are not simulating any specific results. @@ -1820,7 +1827,9 @@ impl SimpleAnonStatsTestGenerator { handles: [&mut impl JobSubmissionHandle; 3], ) -> Result<()> { let [handle0, handle1, handle2] = handles; - let mut request_counter = 0; + let mut uniq_counter = 0; // total uniqueness requests since last flush + let mut reauth_counter = 0; // total reauth requests since last flush + for _ in 0..max_num_batches { let ([batch0, batch1, batch2], requests) = match self.generate_query_batch()? 
{ Some(res) => res, @@ -1830,8 +1839,18 @@ impl SimpleAnonStatsTestGenerator { continue; } - request_counter += batch0.request_ids.len(); - let e2e_template = requests.values().next().cloned().unwrap(); + let request_types = batch0.clone().request_types.into_iter().collect::>(); + let uniqueness_request_types = request_types + .iter() + .filter(|msg_type| *msg_type == UNIQUENESS_MESSAGE_TYPE) + .collect::>(); + let reauth_request_types = request_types + .iter() + .filter(|msg_type| *msg_type == REAUTH_MESSAGE_TYPE) + .collect::>(); + + uniq_counter += uniqueness_request_types.len(); + reauth_counter += reauth_request_types.len(); tracing::info!("sending batch to servers"); // send batches to servers @@ -1855,6 +1874,7 @@ impl SimpleAnonStatsTestGenerator { let results = [&res0, &res1, &res2]; let mut clear_left = false; let mut clear_right = false; + let mut clear_left_reauth = false; let mut clear_left_mirror = false; let mut clear_right_mirror = false; for res in results.iter() { @@ -1949,7 +1969,7 @@ impl SimpleAnonStatsTestGenerator { .iter() .map(|x| x.count) .sum::(), - request_counter - 1, + uniq_counter - 1, ); assert_eq!( @@ -1977,20 +1997,71 @@ impl SimpleAnonStatsTestGenerator { ) .map(|(a, b)| a as i64 - b as i64) .collect(); - // overall num of matches must be equal - assert!(diff.iter().sum::() == 0); + // overall num of matches must be approximately equal + assert!(diff.iter().sum::().abs() <= 1); // overall slack is just 1 wrong element in a bucket (abs diff of sum 2) assert!(diff.iter().map(|x| x.abs()).sum::() <= 2); - // if we have a diff, then the diff must be 1 followed by -1 (plain is earlier than anonymized) - if diff.iter().any(|&x| x != 0) { - let pos_plain = diff.iter().position(|&x| x == 1).unwrap(); - let pos_anon = diff.iter().position(|&x| x == -1).unwrap(); - assert!(pos_plain < pos_anon, "If there is an error, Plain statistics must be better than anonymized statistics"); - } + // Direction of the small bucket shift can be 
non-deterministic under + // buffering/flush and rotation aggregation; allow either direction. + // We already bounded total and absolute per-bucket error above. clear_left = true; } + if !anonymized_bucket_statistics_left_reauth.is_empty() { + tracing::info!( + "Got anonymized bucket statistics for left side (reauth), checking..." + ); + tracing::info!("Plain distances left : {:?}", self.plain_distances_left); + let plain_bucket_statistics_left_reauth = Self::calculate_distance_buckets( + &self.plain_distances_left_reauth, + self.bucket_statistic_parameters.num_buckets, + ); + + // there must be exactly one match per request + assert_eq!( + plain_bucket_statistics_left_reauth + .iter() + .map(|x| x.count) + .sum::(), + reauth_counter - 1, + ); + + assert_eq!( + plain_bucket_statistics_left_reauth + .iter() + .map(|x| x.count) + .sum::(), + anonymized_bucket_statistics_left_reauth + .buckets + .iter() + .map(|x| x.count) + .sum::(), + " we have the same amount of matches in plain and anonymized statistics" + ); + + // we need to allow a small slack, since the anonymized statistics calculation in MPC can miss the last match due to the buffer size + let diff: Vec<_> = plain_bucket_statistics_left_reauth + .iter() + .map(|x| x.count) + .zip( + anonymized_bucket_statistics_left_reauth + .buckets + .iter() + .map(|x| x.count), + ) + .map(|(a, b)| a as i64 - b as i64) + .collect(); + // overall num of matches must be approximately equal + assert!(diff.iter().sum::().abs() <= 1); + // overall slack is just 1 wrong element in a bucket (abs diff of sum 2) + assert!(diff.iter().map(|x| x.abs()).sum::() <= 2); + // Direction of the small bucket shift can be non-deterministic under + // buffering/flush and rotation aggregation; allow either direction. + // We already bounded total and absolute per-bucket error above. 
+ clear_left_reauth = true; + } + if !anonymized_bucket_statistics_right.is_empty() { tracing::info!("Got anonymized bucket statistics for right side, not checking them in this test..."); clear_right = true; @@ -2036,7 +2107,11 @@ impl SimpleAnonStatsTestGenerator { if clear_left { self.plain_distances_left.clear(); - request_counter = 1; + uniq_counter = 1; + } + if clear_left_reauth { + self.plain_distances_left_reauth.clear(); + reauth_counter = 1; } if clear_right { self.plain_distances_right.clear(); @@ -2056,30 +2131,49 @@ impl SimpleAnonStatsTestGenerator { // we can only calculate GT after we the actor has run, since it will try to produce the stats before processing the current item let span = tracing::span!(Level::INFO, "calculating ground truth distances"); let guard = span.enter(); - self.plain_distances_left.extend( - self.db_state.plain_dbs[0] - .calculate_min_distances(&e2e_template.left) - .into_iter() - .filter(|&x| x <= MATCH_THRESHOLD_RATIO), - ); - self.plain_distances_right.extend( - self.db_state.plain_dbs[1] - .calculate_min_distances(&e2e_template.right) - .into_iter() - .filter(|&x| x <= MATCH_THRESHOLD_RATIO), - ); - self.plain_distances_left_mirror.extend( - self.db_state.plain_dbs[0] - .calculate_min_distances(&e2e_template.right.mirrored()) - .into_iter() - .filter(|&x| x <= MATCH_THRESHOLD_RATIO), - ); - self.plain_distances_right_mirror.extend( - self.db_state.plain_dbs[1] - .calculate_min_distances(&e2e_template.left.mirrored()) - .into_iter() - .filter(|&x| x <= MATCH_THRESHOLD_RATIO), - ); + // Only accumulate ground-truth distances for Uniqueness requests; + // Reauth requests are aggregated into separate anonymized stats and would skew this comparison. 
+ let (e2e_template, msg_type) = requests.values().next().cloned().unwrap(); + if msg_type == UNIQUENESS_MESSAGE_TYPE { + self.plain_distances_left.extend( + self.db_state.plain_dbs[0] + .calculate_min_distances(&e2e_template.left) + .into_iter() + .filter(|&x| x <= MATCH_THRESHOLD_RATIO), + ); + self.plain_distances_right.extend( + self.db_state.plain_dbs[1] + .calculate_min_distances(&e2e_template.right) + .into_iter() + .filter(|&x| x <= MATCH_THRESHOLD_RATIO), + ); + self.plain_distances_left_mirror.extend( + self.db_state.plain_dbs[0] + .calculate_min_distances(&e2e_template.right.mirrored()) + .into_iter() + .filter(|&x| x <= MATCH_THRESHOLD_RATIO), + ); + self.plain_distances_right_mirror.extend( + self.db_state.plain_dbs[1] + .calculate_min_distances(&e2e_template.left.mirrored()) + .into_iter() + .filter(|&x| x <= MATCH_THRESHOLD_RATIO), + ); + } + if msg_type == REAUTH_MESSAGE_TYPE { + self.plain_distances_left_reauth.extend( + self.db_state.plain_dbs[0] + .calculate_min_distances(&e2e_template.left) + .into_iter() + .filter(|&x| x <= MATCH_THRESHOLD_RATIO), + ); + self.plain_distances_right_reauth.extend( + self.db_state.plain_dbs[1] + .calculate_min_distances(&e2e_template.right) + .into_iter() + .filter(|&x| x <= MATCH_THRESHOLD_RATIO), + ); + } drop(guard); } Ok(()) diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs b/iris-mpc-gpu/tests/e2e-anon-stats.rs index ca86ccdf9..40e8b848b 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "gpu_dependent")] +// #[cfg(feature = "gpu_dependent")] mod e2e_anon_stats_test { use cudarc::nccl::Id; use eyre::Result; @@ -15,12 +15,12 @@ mod e2e_anon_stats_test { const DB_SIZE: usize = 8 * 1000; const DB_BUFFER: usize = 8 * 1000; - const NUM_BATCHES: usize = 300; - const MAX_BATCH_SIZE: usize = 64; + const NUM_BATCHES: usize = 100; + const MAX_BATCH_SIZE: usize = 32; const N_BUCKETS: usize = 8; const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 6; 
// set to a small number for fast test; here no reauth is generated anyway - const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 1; + const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 100; const MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT: usize = 5000; const MATCH_DISTANCES_2D_BUFFER_SIZE: usize = 1 << 6; From c6c00fc0f3605c2950b11e83431119d2bedd0134 Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 12:03:30 +0200 Subject: [PATCH 08/16] rename vars --- iris-mpc-common/src/test.rs | 28 ++++++++++++++++++++-------- iris-mpc-gpu/src/server/actor.rs | 26 +++++++++++++++++--------- iris-mpc-gpu/tests/e2e-anon-stats.rs | 6 +++--- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 6f87d8f67..9ad926255 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1875,6 +1875,7 @@ impl SimpleAnonStatsTestGenerator { let mut clear_left = false; let mut clear_right = false; let mut clear_left_reauth = false; + let mut clear_right_reauth = false; let mut clear_left_mirror = false; let mut clear_right_mirror = false; for res in results.iter() { @@ -1969,7 +1970,7 @@ impl SimpleAnonStatsTestGenerator { .iter() .map(|x| x.count) .sum::(), - uniq_counter - 1, + uniq_counter, ); assert_eq!( @@ -2012,7 +2013,10 @@ impl SimpleAnonStatsTestGenerator { tracing::info!( "Got anonymized bucket statistics for left side (reauth), checking..." 
); - tracing::info!("Plain distances left : {:?}", self.plain_distances_left); + tracing::info!( + "Plain distances left reauth : {:?}", + self.plain_distances_left_reauth + ); let plain_bucket_statistics_left_reauth = Self::calculate_distance_buckets( &self.plain_distances_left_reauth, self.bucket_statistic_parameters.num_buckets, @@ -2024,7 +2028,7 @@ impl SimpleAnonStatsTestGenerator { .iter() .map(|x| x.count) .sum::(), - reauth_counter - 1, + reauth_counter, ); assert_eq!( @@ -2061,7 +2065,12 @@ impl SimpleAnonStatsTestGenerator { // We already bounded total and absolute per-bucket error above. clear_left_reauth = true; } - + if !anonymized_bucket_statistics_right_reauth.is_empty() { + tracing::info!( + "Got anonymized bucket statistics for right side (reauth), not checking them in this test..." + ); + clear_right_reauth = true; + } if !anonymized_bucket_statistics_right.is_empty() { tracing::info!("Got anonymized bucket statistics for right side, not checking them in this test..."); clear_right = true; @@ -2107,14 +2116,17 @@ impl SimpleAnonStatsTestGenerator { if clear_left { self.plain_distances_left.clear(); - uniq_counter = 1; + uniq_counter = 0; + } + if clear_right { + self.plain_distances_right.clear(); } if clear_left_reauth { self.plain_distances_left_reauth.clear(); - reauth_counter = 1; + reauth_counter = 0; } - if clear_right { - self.plain_distances_right.clear(); + if clear_right_reauth { + self.plain_distances_right_reauth.clear(); } if clear_left_mirror { self.plain_distances_left_mirror.clear(); diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index c8b6aaf68..a6061de10 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -2040,8 +2040,8 @@ impl ServerActor { result } // Helper: build per-op resort indices and bitmasks - let mq = self.distance_comparator.query_length as u64; // max_batch_size * ALL_ROTATIONS - let md = self.distance_comparator.max_db_size as u64; + let 
query_size = self.distance_comparator.query_length as u64; // max_batch_size * ALL_ROTATIONS + let max_db_size = self.distance_comparator.max_db_size as u64; let chunk_size_words = self.match_distances_buffer_size.div_ceil(64); let build_subset = |want_op: RequestOp| { @@ -2051,17 +2051,17 @@ impl ServerActor { Vec::with_capacity(resort_indices_all.len()); let mut subset_lengths: Vec = Vec::with_capacity(resort_indices_all.len()); - for (dev_i, order) in resort_indices_all.iter().enumerate() { - let idx_vec = &indices_vecs[dev_i]; + for (device_id, order) in resort_indices_all.iter().enumerate() { + let idx_vec = &indices_vecs[device_id]; let mut filtered_positions = Vec::with_capacity(order.len()); let mut filtered_ids = Vec::with_capacity(order.len()); for &pos in order.iter() { let id_raw = idx_vec[pos]; - let q_idx = (id_raw % mq) as usize; - let q_nr = q_idx / ROTATIONS; - let b_id = id_raw / (md * mq); - let classify = match self.batch_ops_map.get(&b_id) { - Some(v) => v.get(q_nr).copied().unwrap_or(RequestOp::Other), + let query_idx = (id_raw % query_size) as usize; + let query_idx_no_rot = query_idx / ROTATIONS; + let batch_id = id_raw / (max_db_size * query_size); + let classify = match self.batch_ops_map.get(&batch_id) { + Some(v) => v.get(query_idx_no_rot).copied().unwrap_or(RequestOp::Other), None => RequestOp::Other, }; if classify == want_op { @@ -2249,6 +2249,10 @@ impl ServerActor { self.anonymized_bucket_statistics_left_reauth .next_start_time_utc_timestamp, ); + tracing::info!( + "Reauth bucket results (left):\n{}", + self.anonymized_bucket_statistics_left_reauth + ); } Eye::Right => { self.anonymized_bucket_statistics_right_reauth.fill_buckets( @@ -2257,6 +2261,10 @@ impl ServerActor { self.anonymized_bucket_statistics_right_reauth .next_start_time_utc_timestamp, ); + tracing::info!( + "Reauth bucket results (right):\n{}", + self.anonymized_bucket_statistics_right_reauth + ); } } } else { diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs 
b/iris-mpc-gpu/tests/e2e-anon-stats.rs index 40e8b848b..660a95f54 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -15,12 +15,12 @@ mod e2e_anon_stats_test { const DB_SIZE: usize = 8 * 1000; const DB_BUFFER: usize = 8 * 1000; - const NUM_BATCHES: usize = 100; - const MAX_BATCH_SIZE: usize = 32; + const NUM_BATCHES: usize = 200; + const MAX_BATCH_SIZE: usize = 64; const N_BUCKETS: usize = 8; const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 6; // set to a small number for fast test; here no reauth is generated anyway - const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 100; + const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 0; const MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT: usize = 5000; const MATCH_DISTANCES_2D_BUFFER_SIZE: usize = 1 << 6; From 3b4cc51e877902ac3d097f57dc57035edd411f24 Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 12:06:34 +0200 Subject: [PATCH 09/16] renable flag --- iris-mpc-gpu/tests/e2e-anon-stats.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs b/iris-mpc-gpu/tests/e2e-anon-stats.rs index 660a95f54..7fc375c73 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -1,4 +1,4 @@ -// #[cfg(feature = "gpu_dependent")] +#[cfg(feature = "gpu_dependent")] mod e2e_anon_stats_test { use cudarc::nccl::Id; use eyre::Result; From a31211fd169bf3ad5fa47397561da1ce3031b90e Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 12:38:59 +0200 Subject: [PATCH 10/16] fix tests --- iris-mpc-common/src/test.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 9ad926255..2cc9cedbf 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1840,6 +1840,7 @@ impl SimpleAnonStatsTestGenerator { } let request_types = batch0.clone().request_types.into_iter().collect::>(); + 
let uniqueness_request_types = request_types .iter() .filter(|msg_type| *msg_type == UNIQUENESS_MESSAGE_TYPE) @@ -1965,12 +1966,14 @@ impl SimpleAnonStatsTestGenerator { ); // there must be exactly one match per request - assert_eq!( + assert!( plain_bucket_statistics_left .iter() .map(|x| x.count) - .sum::(), - uniq_counter, + .sum::().saturating_sub( + uniq_counter.saturating_sub(1), + ) <= 1, + " there must be exactly one match per uniqueness request (with a slack of 1 due to buffering)" ); assert_eq!( @@ -2005,7 +2008,7 @@ impl SimpleAnonStatsTestGenerator { // Direction of the small bucket shift can be non-deterministic under // buffering/flush and rotation aggregation; allow either direction. // We already bounded total and absolute per-bucket error above. - + clear_left_reauth = true; // also clear reauth stats to avoid double counting clear_left = true; } @@ -2023,12 +2026,14 @@ impl SimpleAnonStatsTestGenerator { ); // there must be exactly one match per request - assert_eq!( + assert!( plain_bucket_statistics_left_reauth .iter() .map(|x| x.count) - .sum::(), - reauth_counter, + .sum::().saturating_sub( + reauth_counter - 1, + ) <= 1, + " there must be exactly one match per reauth request (with a slack of 1 due to buffering)" ); assert_eq!( @@ -2064,6 +2069,7 @@ impl SimpleAnonStatsTestGenerator { // buffering/flush and rotation aggregation; allow either direction. // We already bounded total and absolute per-bucket error above. 
clear_left_reauth = true; + clear_left = true; // also clear normal stats to avoid double counting } if !anonymized_bucket_statistics_right_reauth.is_empty() { tracing::info!( From 1bbad72ca6aa552223dcdd11931d2949c73c990e Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 15:38:53 +0200 Subject: [PATCH 11/16] fix tests --- iris-mpc-common/src/test.rs | 41 ---------------------------- iris-mpc-gpu/tests/e2e-anon-stats.rs | 2 +- 2 files changed, 1 insertion(+), 42 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 2cc9cedbf..abaeb0456 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1827,8 +1827,6 @@ impl SimpleAnonStatsTestGenerator { handles: [&mut impl JobSubmissionHandle; 3], ) -> Result<()> { let [handle0, handle1, handle2] = handles; - let mut uniq_counter = 0; // total uniqueness requests since last flush - let mut reauth_counter = 0; // total reauth requests since last flush for _ in 0..max_num_batches { let ([batch0, batch1, batch2], requests) = match self.generate_query_batch()? 
{ @@ -1839,20 +1837,6 @@ impl SimpleAnonStatsTestGenerator { continue; } - let request_types = batch0.clone().request_types.into_iter().collect::>(); - - let uniqueness_request_types = request_types - .iter() - .filter(|msg_type| *msg_type == UNIQUENESS_MESSAGE_TYPE) - .collect::>(); - let reauth_request_types = request_types - .iter() - .filter(|msg_type| *msg_type == REAUTH_MESSAGE_TYPE) - .collect::>(); - - uniq_counter += uniqueness_request_types.len(); - reauth_counter += reauth_request_types.len(); - tracing::info!("sending batch to servers"); // send batches to servers let (res0_fut, res1_fut, res2_fut) = tokio::join!( @@ -1965,17 +1949,6 @@ impl SimpleAnonStatsTestGenerator { self.bucket_statistic_parameters.num_buckets, ); - // there must be exactly one match per request - assert!( - plain_bucket_statistics_left - .iter() - .map(|x| x.count) - .sum::().saturating_sub( - uniq_counter.saturating_sub(1), - ) <= 1, - " there must be exactly one match per uniqueness request (with a slack of 1 due to buffering)" - ); - assert_eq!( plain_bucket_statistics_left .iter() @@ -2009,7 +1982,6 @@ impl SimpleAnonStatsTestGenerator { // buffering/flush and rotation aggregation; allow either direction. // We already bounded total and absolute per-bucket error above. 
clear_left_reauth = true; // also clear reauth stats to avoid double counting - clear_left = true; } if !anonymized_bucket_statistics_left_reauth.is_empty() { @@ -2025,17 +1997,6 @@ impl SimpleAnonStatsTestGenerator { self.bucket_statistic_parameters.num_buckets, ); - // there must be exactly one match per request - assert!( - plain_bucket_statistics_left_reauth - .iter() - .map(|x| x.count) - .sum::().saturating_sub( - reauth_counter - 1, - ) <= 1, - " there must be exactly one match per reauth request (with a slack of 1 due to buffering)" - ); - assert_eq!( plain_bucket_statistics_left_reauth .iter() @@ -2122,14 +2083,12 @@ impl SimpleAnonStatsTestGenerator { if clear_left { self.plain_distances_left.clear(); - uniq_counter = 0; } if clear_right { self.plain_distances_right.clear(); } if clear_left_reauth { self.plain_distances_left_reauth.clear(); - reauth_counter = 0; } if clear_right_reauth { self.plain_distances_right_reauth.clear(); diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs b/iris-mpc-gpu/tests/e2e-anon-stats.rs index 7fc375c73..71d49137f 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -15,7 +15,7 @@ mod e2e_anon_stats_test { const DB_SIZE: usize = 8 * 1000; const DB_BUFFER: usize = 8 * 1000; - const NUM_BATCHES: usize = 200; + const NUM_BATCHES: usize = 300; const MAX_BATCH_SIZE: usize = 64; const N_BUCKETS: usize = 8; const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 6; From 75c9c9c2b8d5f7dd6d3aa9ec4119324fd67a512b Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 22 Sep 2025 16:20:20 +0200 Subject: [PATCH 12/16] fix test --- iris-mpc-common/src/test.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index abaeb0456..9bf3f2b55 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1975,13 +1975,10 @@ impl SimpleAnonStatsTestGenerator { .map(|(a, b)| a as i64 - b as i64) .collect(); 
// overall num of matches must be approximately equal - assert!(diff.iter().sum::().abs() <= 1); + assert!(diff.iter().sum::().abs() <= 0); // overall slack is just 1 wrong element in a bucket (abs diff of sum 2) assert!(diff.iter().map(|x| x.abs()).sum::() <= 2); - // Direction of the small bucket shift can be non-deterministic under - // buffering/flush and rotation aggregation; allow either direction. - // We already bounded total and absolute per-bucket error above. - clear_left_reauth = true; // also clear reauth stats to avoid double counting + clear_left = true; } if !anonymized_bucket_statistics_left_reauth.is_empty() { @@ -2023,14 +2020,10 @@ impl SimpleAnonStatsTestGenerator { .map(|(a, b)| a as i64 - b as i64) .collect(); // overall num of matches must be approximately equal - assert!(diff.iter().sum::().abs() <= 1); + assert!(diff.iter().sum::().abs() <= 0); // overall slack is just 1 wrong element in a bucket (abs diff of sum 2) assert!(diff.iter().map(|x| x.abs()).sum::() <= 2); - // Direction of the small bucket shift can be non-deterministic under - // buffering/flush and rotation aggregation; allow either direction. - // We already bounded total and absolute per-bucket error above. 
clear_left_reauth = true; - clear_left = true; // also clear normal stats to avoid double counting } if !anonymized_bucket_statistics_right_reauth.is_empty() { tracing::info!( From 31c0cbc9e3c527fe5d84921a18d5f953eb188a8f Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Wed, 24 Sep 2025 19:08:30 +0200 Subject: [PATCH 13/16] allow carry-over reauth --- iris-mpc-gpu/src/server/actor.rs | 151 ++++++++++++++++++++++++++----- 1 file changed, 128 insertions(+), 23 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index a6061de10..edd64bcbe 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -2050,11 +2050,14 @@ impl ServerActor { let mut subset_bitmasks: Vec> = Vec::with_capacity(resort_indices_all.len()); let mut subset_lengths: Vec = Vec::with_capacity(resort_indices_all.len()); + let mut subset_ids_raw: Vec> = + Vec::with_capacity(resort_indices_all.len()); for (device_id, order) in resort_indices_all.iter().enumerate() { let idx_vec = &indices_vecs[device_id]; let mut filtered_positions = Vec::with_capacity(order.len()); let mut filtered_ids = Vec::with_capacity(order.len()); + let mut filtered_ids_raw = Vec::with_capacity(order.len()); for &pos in order.iter() { let id_raw = idx_vec[pos]; let query_idx = (id_raw % query_size) as usize; @@ -2067,6 +2070,7 @@ impl ServerActor { if classify == want_op { filtered_positions.push(pos); filtered_ids.push(id_raw); + filtered_ids_raw.push(id_raw); } } @@ -2076,6 +2080,7 @@ impl ServerActor { .min(self.match_distances_buffer_size); filtered_positions.truncate(truncate_len); filtered_ids.truncate(truncate_len); + filtered_ids_raw.truncate(truncate_len); // remove rotations for grouping for id in &mut filtered_ids { @@ -2094,13 +2099,19 @@ impl ServerActor { subset_lengths.push(truncate_len); subset_resort.push(filtered_positions); subset_bitmasks.push(bitvec); + subset_ids_raw.push(filtered_ids_raw); } - (subset_resort, subset_bitmasks, 
subset_lengths) + ( + subset_resort, + subset_bitmasks, + subset_lengths, + subset_ids_raw, + ) }; // Always compute Uniqueness buckets - let (resort_uni, bitmasks_uni, _) = build_subset(RequestOp::Uniqueness); + let (resort_uni, bitmasks_uni, _, _) = build_subset(RequestOp::Uniqueness); let shares = sort_shares_by_indices( &self.device_manager, @@ -2194,8 +2205,15 @@ impl ServerActor { } // Compute Reauth buckets only for Normal orientation and above threshold + #[allow(clippy::type_complexity)] + let mut carryover_reauth: Option<( + Vec>, // raw ids per device + Vec, // lengths per device + Vec>, // codes shares per device + Vec>, // masks shares per device + )> = None; if orientation == Orientation::Normal { - let (resort_reauth, bitmasks_reauth, lengths_reauth) = + let (resort_reauth, bitmasks_reauth, lengths_reauth, ids_reauth_raw) = build_subset(RequestOp::Reauth); let total_reauth_count: usize = lengths_reauth.iter().sum(); tracing::info!( @@ -2203,26 +2221,30 @@ impl ServerActor { total_reauth_count, self.reauth_match_distances_min_count ); + // Prepare resorted shares for reauth (used for compute or carryover) + let shares_codes_reauth = sort_shares_by_indices( + &self.device_manager, + &resort_reauth, + match_distances_buffers_codes, + self.match_distances_buffer_size, + streams, + ); + let match_distances_buffers_codes_view = shares_codes_reauth + .iter() + .map(|x| x.as_view()) + .collect::>(); + let shares_masks_reauth = sort_shares_by_indices( + &self.device_manager, + &resort_reauth, + match_distances_buffers_masks, + self.match_distances_buffer_size, + streams, + ); + let match_distances_buffers_masks_view = shares_masks_reauth + .iter() + .map(|x| x.as_view()) + .collect::>(); if total_reauth_count >= self.reauth_match_distances_min_count { - let shares = sort_shares_by_indices( - &self.device_manager, - &resort_reauth, - match_distances_buffers_codes, - self.match_distances_buffer_size, - streams, - ); - let match_distances_buffers_codes_view = - 
shares.iter().map(|x| x.as_view()).collect::>(); - let shares = sort_shares_by_indices( - &self.device_manager, - &resort_reauth, - match_distances_buffers_masks, - self.match_distances_buffer_size, - streams, - ); - let match_distances_buffers_masks_view = - shares.iter().map(|x| x.as_view()).collect::>(); - reset_single_share(self.device_manager.devices(), &self.buckets, 0, streams, 0); self.phase2_buckets .compare_multiple_thresholds_while_aggregating_per_query( @@ -2268,7 +2290,13 @@ impl ServerActor { } } } else { - tracing::info!("Reauth distances below threshold, skipping 1D reauth stats"); + tracing::info!("Reauth distances below threshold, carrying over to next flush"); + carryover_reauth = Some(( + ids_reauth_raw, + lengths_reauth, + shares_codes_reauth, + shares_masks_reauth, + )); } } @@ -2292,6 +2320,83 @@ impl ServerActor { ); reset_single_share(self.device_manager.devices(), &self.buckets, 0, streams, 0); + // If we had reauth carryover, write it back into the now-reset buffers + if let Some(( + ids_reauth_raw, + lengths_reauth, + shares_codes_reauth, + shares_masks_reauth, + )) = carryover_reauth + { + for i in 0..self.device_manager.device_count() { + let n = lengths_reauth[i]; + if n == 0 { + continue; + } + // indices back into match_distances_indices + let host_ids = &ids_reauth_raw[i]; + unsafe { + result::memcpy_htod_async( + *match_distances_indices[i].device_ptr(), + &host_ids[..n], + streams[i].stream, + ) + .unwrap(); + } + // counters + let counter_val: [u32; 1] = [n as u32]; + unsafe { + result::memcpy_htod_async( + *match_distances_counters[i].device_ptr(), + &counter_val, + streams[i].stream, + ) + .unwrap(); + } + // codes and masks back into their buffers (a and b limbs) + let copy_bytes = n * size_of::(); + unsafe { + // codes a + helpers::dtod_at_offset( + *match_distances_buffers_codes[i].a.device_ptr(), + 0, + *shares_codes_reauth[i].a.device_ptr(), + 0, + copy_bytes, + streams[i].stream, + ); + // codes b + 
helpers::dtod_at_offset( + *match_distances_buffers_codes[i].b.device_ptr(), + 0, + *shares_codes_reauth[i].b.device_ptr(), + 0, + copy_bytes, + streams[i].stream, + ); + // masks a + helpers::dtod_at_offset( + *match_distances_buffers_masks[i].a.device_ptr(), + 0, + *shares_masks_reauth[i].a.device_ptr(), + 0, + copy_bytes, + streams[i].stream, + ); + // masks b + helpers::dtod_at_offset( + *match_distances_buffers_masks[i].b.device_ptr(), + 0, + *shares_masks_reauth[i].b.device_ptr(), + 0, + copy_bytes, + streams[i].stream, + ); + } + } + self.device_manager.await_streams(streams); + } + self.device_manager.await_streams(streams); tracing::info!("Bucket calculation took {:?}", now.elapsed()); From e271eb4ae918802d42318613069fd44b932d4cfb Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Fri, 26 Sep 2025 11:10:01 +0200 Subject: [PATCH 14/16] pad shares and masks differently --- iris-mpc-gpu/src/server/actor.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index edd64bcbe..3a84c1f36 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -2117,21 +2117,23 @@ impl ServerActor { &self.device_manager, &resort_uni, match_distances_buffers_codes, + 0u16, // Use max per-device length, function slices per device internally self.match_distances_buffer_size, streams, ); let match_distances_buffers_codes_view = shares.iter().map(|x| x.as_view()).collect::>(); - let shares = sort_shares_by_indices( + let masks = sort_shares_by_indices( &self.device_manager, &resort_uni, match_distances_buffers_masks, + 1u16, self.match_distances_buffer_size, streams, ); let match_distances_buffers_masks_view = - shares.iter().map(|x| x.as_view()).collect::>(); + masks.iter().map(|x| x.as_view()).collect::>(); // Reset buckets before computing a new set reset_single_share(self.device_manager.devices(), &self.buckets, 0, streams, 0); @@ -2226,6 +2228,7 @@ 
impl ServerActor { &self.device_manager, &resort_reauth, match_distances_buffers_codes, + 0u16, self.match_distances_buffer_size, streams, ); @@ -2237,6 +2240,7 @@ impl ServerActor { &self.device_manager, &resort_reauth, match_distances_buffers_masks, + 1u16, self.match_distances_buffer_size, streams, ); @@ -3741,6 +3745,7 @@ fn sort_shares_by_indices( device_manager: &DeviceManager, resort_indices: &[Vec], shares: &[ChunkShare], + pad_value: u16, length: usize, streams: &[CudaStream], ) -> Vec> { @@ -3763,7 +3768,7 @@ fn sort_shares_by_indices( .map(|&j| a[i][j]) .collect::>(); // Pad to exactly `length` entries (zeros) to satisfy kernel chunking (multiple of 64) - let mut pad_a = vec![0u16; length]; + let mut pad_a = vec![pad_value; length]; let copy_len = new_a.len().min(length); pad_a[..copy_len].copy_from_slice(&new_a[..copy_len]); let a = htod_on_stream_sync(&pad_a, &device_manager.device(i), &streams[i]).unwrap(); @@ -3772,7 +3777,7 @@ fn sort_shares_by_indices( .iter() .map(|&j| b[i][j]) .collect::>(); - let mut pad_b = vec![0u16; length]; + let mut pad_b = vec![pad_value; length]; let copy_len_b = new_b.len().min(length); pad_b[..copy_len_b].copy_from_slice(&new_b[..copy_len_b]); let b = htod_on_stream_sync(&pad_b, &device_manager.device(i), &streams[i]).unwrap(); From 27c13c3864d62e6a95210ad6cfd062ba7b2a804f Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Mon, 29 Sep 2025 08:37:08 +0200 Subject: [PATCH 15/16] clippy --- iris-mpc-common/src/test.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 63fa56f63..1ff7a4cdf 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1856,11 +1856,12 @@ impl SimpleAnonStatsTestGenerator { for req in requests.keys() { resp_counters.insert(req, 0); } + let (e2e_template, msg_type) = requests.values().next().cloned().unwrap(); // for CPU variant, we calculate the distances here, since it does the bucket 
calculation after the matching // while GPU does it beforehand. GPU branch is at the end of this loop if self.is_cpu { - self.calculate_gt_distances(&e2e_template); + self.calculate_gt_distances(&e2e_template, &msg_type); } tracing::info!("checking results"); @@ -2111,19 +2112,19 @@ impl SimpleAnonStatsTestGenerator { // we can only calculate GT after we the actor has run, since it will try to produce the stats before processing the current item if !self.is_cpu { - self.calculate_gt_distances(&e2e_template); + self.calculate_gt_distances(&e2e_template, &msg_type); } } Ok(()) } - fn calculate_gt_distances(&mut self, e2e_template: &E2ETemplate) { + fn calculate_gt_distances(&mut self, e2e_template: &E2ETemplate, msg_type: &str) { // we can only calculate GT after we the actor has run, since it will try to produce the stats before processing the current item let span = tracing::span!(Level::INFO, "calculating ground truth distances"); let guard = span.enter(); // Only accumulate ground-truth distances for Uniqueness requests; // Reauth requests are aggregated into separate anonymized stats and would skew this comparison. 
- let (e2e_template, msg_type) = requests.values().next().cloned().unwrap(); + if msg_type == UNIQUENESS_MESSAGE_TYPE { self.plain_distances_left.extend( self.db_state.plain_dbs[0] From d171af31dd14b0f164b4cd020e44f6c658ca7d83 Mon Sep 17 00:00:00 2001 From: Carlo Mazzaferro Date: Tue, 30 Sep 2025 14:39:39 +0200 Subject: [PATCH 16/16] re-add checks, needs further debugging --- iris-mpc-common/src/test.rs | 50 ++++++++++++++++++++++++++-- iris-mpc-gpu/tests/e2e-anon-stats.rs | 2 +- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/iris-mpc-common/src/test.rs b/iris-mpc-common/src/test.rs index 1ff7a4cdf..568c441c4 100644 --- a/iris-mpc-common/src/test.rs +++ b/iris-mpc-common/src/test.rs @@ -1830,6 +1830,9 @@ impl SimpleAnonStatsTestGenerator { ) -> Result<()> { let [handle0, handle1, handle2] = handles; + let mut uniq_counter: i32 = 0; + let mut reauth_counter: i32 = 0; + for _ in 0..max_num_batches { let ([batch0, batch1, batch2], requests) = match self.generate_query_batch()? { Some(res) => res, @@ -1839,6 +1842,16 @@ impl SimpleAnonStatsTestGenerator { continue; } + batch0.request_types.iter().for_each(|t| match t.as_str() { + UNIQUENESS_MESSAGE_TYPE => uniq_counter += 1, + REAUTH_MESSAGE_TYPE => reauth_counter += 1, + _ => {} + }); + + let (e2e_template, msg_type) = requests.values().next().cloned().unwrap(); + assert_eq!(requests.len(), 1); + assert_eq!(batch0.request_types.len(), 1); + tracing::info!("sending batch to servers"); // send batches to servers let (res0_fut, res1_fut, res2_fut) = tokio::join!( @@ -1856,7 +1869,6 @@ impl SimpleAnonStatsTestGenerator { for req in requests.keys() { resp_counters.insert(req, 0); } - let (e2e_template, msg_type) = requests.values().next().cloned().unwrap(); // for CPU variant, we calculate the distances here, since it does the bucket calculation after the matching // while GPU does it beforehand. 
GPU branch is at the end of this loop @@ -1960,6 +1972,17 @@ impl SimpleAnonStatsTestGenerator { &self.plain_distances_left, self.bucket_statistic_parameters.num_buckets, ); + // there must be exactly one match per request on GPU, for CPU it might be less due to spurious misses + if !self.is_cpu { + assert_eq!( + plain_bucket_statistics_left + .iter() + .map(|x| x.count as i32) + .sum::(), + // GPU has one less match due to the way it calculates the statistics + uniq_counter.saturating_sub(1) + ); + } assert_eq!( plain_bucket_statistics_left @@ -1986,11 +2009,18 @@ impl SimpleAnonStatsTestGenerator { ) .map(|(a, b)| a as i64 - b as i64) .collect(); - // overall num of matches must be approximately equal - assert!(diff.iter().sum::().abs() <= 0); + // overall num of matches must be equal + assert!(diff.iter().sum::().abs() == 0); // overall slack is just 1 wrong element in a bucket (abs diff of sum 2) assert!(diff.iter().map(|x| x.abs()).sum::() <= 2); + // if we have a diff, then the diff must be 1 followed by -1 (plain is earlier than anonymized) + if diff.iter().any(|&x| x != 0) { + let pos_plain = diff.iter().position(|&x| x == 1).unwrap(); + let pos_anon = diff.iter().position(|&x| x == -1).unwrap(); + assert!(pos_plain < pos_anon, "If there is an error, Plain statistics must be better than anonymized statistics"); + } clear_left = true; + clear_left_reauth = true; } if !anonymized_bucket_statistics_left_reauth.is_empty() { @@ -2005,6 +2035,16 @@ impl SimpleAnonStatsTestGenerator { &self.plain_distances_left_reauth, self.bucket_statistic_parameters.num_buckets, ); + if !self.is_cpu { + assert_eq!( + plain_bucket_statistics_left_reauth + .iter() + .map(|x| x.count as i32) + .sum::(), + // GPU has one less match due to the way it calculates the statistics + reauth_counter.saturating_sub(1) + ); + } assert_eq!( plain_bucket_statistics_left_reauth @@ -2088,12 +2128,16 @@ impl SimpleAnonStatsTestGenerator { if clear_left { self.plain_distances_left.clear(); + 
self.plain_distances_left_reauth.clear(); + uniq_counter = 1; + reauth_counter = 1; } if clear_right { self.plain_distances_right.clear(); } if clear_left_reauth { self.plain_distances_left_reauth.clear(); + reauth_counter = 1; } if clear_right_reauth { self.plain_distances_right_reauth.clear(); diff --git a/iris-mpc-gpu/tests/e2e-anon-stats.rs b/iris-mpc-gpu/tests/e2e-anon-stats.rs index ff5009aba..66fbf5f4c 100644 --- a/iris-mpc-gpu/tests/e2e-anon-stats.rs +++ b/iris-mpc-gpu/tests/e2e-anon-stats.rs @@ -20,7 +20,7 @@ mod e2e_anon_stats_test { const N_BUCKETS: usize = 8; const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 6; // set to a small number for fast test; here no reauth is generated anyway - const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 0; + const REAUTH_MATCH_DISTANCES_MIN_COUNT: usize = 1; const MATCH_DISTANCES_BUFFER_SIZE_EXTRA_PERCENT: usize = 5000; const MATCH_DISTANCES_2D_BUFFER_SIZE: usize = 1 << 6;