Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion iris-mpc-cpu/src/execution/hawk_main/intra_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub async fn intra_batch_is_match(
assert_eq!(n_requests, search_queries[RIGHT].len());
let n_rotations = search_queries[LEFT].first().map(|r| r.len()).unwrap_or(1);

let batches = Schedule::new(n_sessions, n_requests, n_rotations).batches();
let batches = Schedule::new(n_sessions, n_requests, n_rotations).intra_match_batches();

let (tx, rx) = unbounded_channel::<IsMatch>();

Expand Down
135 changes: 123 additions & 12 deletions iris-mpc-cpu/src/execution/hawk_main/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ impl Schedule {

/// Enumerate all combinations of eye sides, requests, and rotations.
/// Distribute the tasks over a number of sessions.
pub fn batches(&self) -> Vec<Batch> {
/// Note: Should be used exclusively for intra_match_is_batch
/// as it is optimized for its logic
pub fn intra_match_batches(&self) -> Vec<Batch> {
let n_tasks = self.n_requests * self.n_rotations;
let batch_size = n_tasks / self.n_sessions;
let rest_size = n_tasks % self.n_sessions;
Expand Down Expand Up @@ -86,6 +88,45 @@ impl Schedule {
.collect_vec()
}

/// Enumerate all combinations of eye sides, requests, and rotations.
/// Distribute the tasks over a number of sessions.
/// This method is search-aware and weighs central rotations as higher workloads than non-central ones
pub fn search_batches(&self) -> Vec<Batch> {
let n_tasks = self.n_requests * self.n_rotations;
let batch_size = n_tasks / self.n_sessions;
let rest_size = n_tasks % self.n_sessions;

(0..N_EYES)
.flat_map(|i_eye| {
// Iterate requests first and rotations second (contrast with intra_match_batches)
// This ensures that heavier center tasks are grouped with their (lighter) rotations
// This order also opens up the possibility of better memory access patterns, assuming rotations
// of a fixed iris behave similarly
let mut task_iter = (0..self.n_requests).flat_map(move |i_request| {
(0..self.n_rotations).map(move |i_rotation| Task {
i_eye,
i_request,
i_rotation,
is_central: (i_rotation == self.n_rotations / 2),
})
});

(0..self.n_sessions).map(move |i_session| {
// Some sessions get one more task if n_sessions does not divide n_tasks.
let one_more = (i_session < rest_size) as usize;

let tasks = task_iter.by_ref().take(batch_size + one_more).collect_vec();

Batch {
i_eye,
i_session,
tasks,
}
})
})
.collect_vec()
}

pub fn organize_results<T, ROT: Rotations>(
&self,
mut results: HashMap<TaskId, T>,
Expand Down Expand Up @@ -154,25 +195,25 @@ mod test {
use iris_mpc_common::ROTATIONS;

#[test]
fn test_schedule() {
fn test_intra_match_schedule() {
for n_rotations in [1, ROTATIONS] {
test_schedule_impl(1, 0, n_rotations);
test_schedule_impl(1, 1, n_rotations);
test_schedule_impl(1, 2, n_rotations);
test_schedule_impl(10, 1, n_rotations);
test_schedule_impl(1, 10, n_rotations);
test_schedule_impl(7, 10, n_rotations);
test_schedule_impl(10, 30, n_rotations);
test_schedule_impl(10, 97, n_rotations);
test_intra_match_schedule_impl(1, 0, n_rotations);
test_intra_match_schedule_impl(1, 1, n_rotations);
test_intra_match_schedule_impl(1, 2, n_rotations);
test_intra_match_schedule_impl(10, 1, n_rotations);
test_intra_match_schedule_impl(1, 10, n_rotations);
test_intra_match_schedule_impl(7, 10, n_rotations);
test_intra_match_schedule_impl(10, 30, n_rotations);
test_intra_match_schedule_impl(10, 97, n_rotations);
}
}

fn test_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) {
fn test_intra_match_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) {
let n_eyes = N_EYES;
let n_batches = n_eyes * n_sessions;
let n_tasks = n_eyes * n_requests * n_rotations;

let batches = Schedule::new(n_sessions, n_requests, n_rotations).batches();
let batches = Schedule::new(n_sessions, n_requests, n_rotations).intra_match_batches();
assert_eq!(batches.len(), n_batches);

let count_tasks: usize = batches.iter().map(|b| b.tasks.len()).sum();
Expand Down Expand Up @@ -203,6 +244,76 @@ mod test {
assert_eq!(unique_tasks, n_tasks);
}

#[test]
fn test_search_schedule() {
for n_rotations in [1, ROTATIONS] {
test_search_schedule_impl(1, 0, n_rotations);
test_search_schedule_impl(1, 1, n_rotations);
test_search_schedule_impl(1, 2, n_rotations);
test_search_schedule_impl(10, 1, n_rotations);
test_search_schedule_impl(1, 10, n_rotations);
test_search_schedule_impl(7, 10, n_rotations);
test_search_schedule_impl(10, 30, n_rotations);
test_search_schedule_impl(10, 97, n_rotations);
}
}

fn test_search_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) {
let n_eyes = N_EYES;
let n_batches = n_eyes * n_sessions;
let n_tasks = n_eyes * n_requests * n_rotations;

let batches = Schedule::new(n_sessions, n_requests, n_rotations).search_batches();
assert_eq!(batches.len(), n_batches);

let count_tasks: usize = batches.iter().map(|b| b.tasks.len()).sum();
assert_eq!(count_tasks, n_tasks);

let unique_sessions = batches
.iter()
.map(|b| (b.i_eye, b.i_session))
.unique()
.count();
assert_eq!(unique_sessions, n_batches);

let unique_tasks = batches
.iter()
.flat_map(|b| {
assert!(b.i_eye < n_eyes);
assert!(b.i_session < n_sessions);

b.tasks.iter().map(|t| {
assert!(t.i_request < n_requests);
assert!(t.i_rotation < n_rotations);

(b.i_eye, t.i_request, t.i_rotation)
})
})
.unique()
.count();
assert_eq!(unique_tasks, n_tasks);

// Check central-rotation load balance
let minmax = batches
.iter()
.map(|batch| {
batch
.tasks
.iter()
.map(|task| task.is_central as usize)
.sum::<usize>()
})
.minmax();
let dif = match minmax {
itertools::MinMaxResult::NoElements => 0,
itertools::MinMaxResult::OneElement(_) => 0,
itertools::MinMaxResult::MinMax(min, max) => max - min,
};
// The difference might be 2 in contrived cases, but it's
// not worth addressing
assert!(dif <= 1);
}

#[test]
fn test_range_forward_backward() {
assert!(range_forward_backward(0).collect_vec().is_empty());
Expand Down
2 changes: 1 addition & 1 deletion iris-mpc-cpu/src/execution/hawk_main/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ where

let schedule = Schedule::new(n_sessions, n_requests, ROT::N_ROTATIONS);

parallelize(schedule.batches().into_iter().map(per_session)).await?;
parallelize(schedule.search_batches().into_iter().map(per_session)).await?;

let results = schedule.organize_results(collect_results(rx).await?)?;

Expand Down
Loading