From 4706746adf79514abd51627685290cce96973bf6 Mon Sep 17 00:00:00 2001 From: Mihai Date: Wed, 27 Aug 2025 15:03:34 +0300 Subject: [PATCH] add different scheduler for search --- .../src/execution/hawk_main/intra_batch.rs | 2 +- .../src/execution/hawk_main/scheduler.rs | 135 ++++++++++++++++-- .../src/execution/hawk_main/search.rs | 2 +- 3 files changed, 125 insertions(+), 14 deletions(-) diff --git a/iris-mpc-cpu/src/execution/hawk_main/intra_batch.rs b/iris-mpc-cpu/src/execution/hawk_main/intra_batch.rs index 115e81e61..b04157737 100644 --- a/iris-mpc-cpu/src/execution/hawk_main/intra_batch.rs +++ b/iris-mpc-cpu/src/execution/hawk_main/intra_batch.rs @@ -29,7 +29,7 @@ pub async fn intra_batch_is_match( assert_eq!(n_requests, search_queries[RIGHT].len()); let n_rotations = search_queries[LEFT].first().map(|r| r.len()).unwrap_or(1); - let batches = Schedule::new(n_sessions, n_requests, n_rotations).batches(); + let batches = Schedule::new(n_sessions, n_requests, n_rotations).intra_match_batches(); let (tx, rx) = unbounded_channel::(); diff --git a/iris-mpc-cpu/src/execution/hawk_main/scheduler.rs b/iris-mpc-cpu/src/execution/hawk_main/scheduler.rs index 68df7a98e..837333d47 100644 --- a/iris-mpc-cpu/src/execution/hawk_main/scheduler.rs +++ b/iris-mpc-cpu/src/execution/hawk_main/scheduler.rs @@ -54,7 +54,9 @@ impl Schedule { /// Enumerate all combinations of eye sides, requests, and rotations. /// Distribute the tasks over a number of sessions. - pub fn batches(&self) -> Vec { + /// Note: Should be used exclusively for intra_match_is_batch + /// as it is optimized for its logic + pub fn intra_match_batches(&self) -> Vec { let n_tasks = self.n_requests * self.n_rotations; let batch_size = n_tasks / self.n_sessions; let rest_size = n_tasks % self.n_sessions; @@ -86,6 +88,45 @@ impl Schedule { .collect_vec() } + /// Enumerate all combinations of eye sides, requests, and rotations. + /// Distribute the tasks over a number of sessions. + /// This method is search-aware and weighs central rotations as higher workloads than non-central ones + pub fn search_batches(&self) -> Vec { + let n_tasks = self.n_requests * self.n_rotations; + let batch_size = n_tasks / self.n_sessions; + let rest_size = n_tasks % self.n_sessions; + + (0..N_EYES) + .flat_map(|i_eye| { + // Iterate requests first and rotations second (contrast with intra_match_batches) + // This ensures that heavier center tasks are grouped with their (lighter) rotations + // This order also opens up the possibility of better memory access patterns, assuming rotations + // of a fixed iris behave similarly + let mut task_iter = (0..self.n_requests).flat_map(move |i_request| { + (0..self.n_rotations).map(move |i_rotation| Task { + i_eye, + i_request, + i_rotation, + is_central: (i_rotation == self.n_rotations / 2), + }) + }); + + (0..self.n_sessions).map(move |i_session| { + // Some sessions get one more task if n_sessions does not divide n_tasks. + let one_more = (i_session < rest_size) as usize; + + let tasks = task_iter.by_ref().take(batch_size + one_more).collect_vec(); + + Batch { + i_eye, + i_session, + tasks, + } + }) + }) + .collect_vec() + } + pub fn organize_results( &self, mut results: HashMap, @@ -154,25 +195,25 @@ mod test { use iris_mpc_common::ROTATIONS; #[test] - fn test_schedule() { + fn test_intra_match_schedule() { for n_rotations in [1, ROTATIONS] { - test_schedule_impl(1, 0, n_rotations); - test_schedule_impl(1, 1, n_rotations); - test_schedule_impl(1, 2, n_rotations); - test_schedule_impl(10, 1, n_rotations); - test_schedule_impl(1, 10, n_rotations); - test_schedule_impl(7, 10, n_rotations); - test_schedule_impl(10, 30, n_rotations); - test_schedule_impl(10, 97, n_rotations); + test_intra_match_schedule_impl(1, 0, n_rotations); + test_intra_match_schedule_impl(1, 1, n_rotations); + test_intra_match_schedule_impl(1, 2, n_rotations); + test_intra_match_schedule_impl(10, 1, n_rotations); + test_intra_match_schedule_impl(1, 10, n_rotations); + test_intra_match_schedule_impl(7, 10, n_rotations); + test_intra_match_schedule_impl(10, 30, n_rotations); + test_intra_match_schedule_impl(10, 97, n_rotations); } } - fn test_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) { + fn test_intra_match_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) { let n_eyes = N_EYES; let n_batches = n_eyes * n_sessions; let n_tasks = n_eyes * n_requests * n_rotations; - let batches = Schedule::new(n_sessions, n_requests, n_rotations).batches(); + let batches = Schedule::new(n_sessions, n_requests, n_rotations).intra_match_batches(); assert_eq!(batches.len(), n_batches); let count_tasks: usize = batches.iter().map(|b| b.tasks.len()).sum(); @@ -203,6 +244,76 @@ mod test { assert_eq!(unique_tasks, n_tasks); } + #[test] + fn test_search_schedule() { + for n_rotations in [1, ROTATIONS] { + test_search_schedule_impl(1, 0, n_rotations); + test_search_schedule_impl(1, 1, n_rotations); + test_search_schedule_impl(1, 2, n_rotations); + test_search_schedule_impl(10, 1, n_rotations); + test_search_schedule_impl(1, 10, n_rotations); + test_search_schedule_impl(7, 10, n_rotations); + test_search_schedule_impl(10, 30, n_rotations); + test_search_schedule_impl(10, 97, n_rotations); + } + } + + fn test_search_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) { + let n_eyes = N_EYES; + let n_batches = n_eyes * n_sessions; + let n_tasks = n_eyes * n_requests * n_rotations; + + let batches = Schedule::new(n_sessions, n_requests, n_rotations).search_batches(); + assert_eq!(batches.len(), n_batches); + + let count_tasks: usize = batches.iter().map(|b| b.tasks.len()).sum(); + assert_eq!(count_tasks, n_tasks); + + let unique_sessions = batches + .iter() + .map(|b| (b.i_eye, b.i_session)) + .unique() + .count(); + assert_eq!(unique_sessions, n_batches); + + let unique_tasks = batches + .iter() + .flat_map(|b| { + assert!(b.i_eye < n_eyes); + assert!(b.i_session < n_sessions); + + b.tasks.iter().map(|t| { + assert!(t.i_request < n_requests); + assert!(t.i_rotation < n_rotations); + + (b.i_eye, t.i_request, t.i_rotation) + }) + }) + .unique() + .count(); + assert_eq!(unique_tasks, n_tasks); + + // Check central-rotation load balance + let minmax = batches + .iter() + .map(|batch| { + batch + .tasks + .iter() + .map(|task| task.is_central as usize) + .sum::() + }) + .minmax(); + let dif = match minmax { + itertools::MinMaxResult::NoElements => 0, + itertools::MinMaxResult::OneElement(_) => 0, + itertools::MinMaxResult::MinMax(min, max) => max - min, + }; + // The difference might be 2 in contrived cases, but it's + // not worth addressing + assert!(dif <= 1); + } + #[test] fn test_range_forward_backward() { assert!(range_forward_backward(0).collect_vec().is_empty()); diff --git a/iris-mpc-cpu/src/execution/hawk_main/search.rs b/iris-mpc-cpu/src/execution/hawk_main/search.rs index 428888d15..a2a2069c4 100644 --- a/iris-mpc-cpu/src/execution/hawk_main/search.rs +++ b/iris-mpc-cpu/src/execution/hawk_main/search.rs @@ -66,7 +66,7 @@ where let schedule = Schedule::new(n_sessions, n_requests, ROT::N_ROTATIONS); - parallelize(schedule.batches().into_iter().map(per_session)).await?; + parallelize(schedule.search_batches().into_iter().map(per_session)).await?; let results = schedule.organize_results(collect_results(rx).await?)?;