Skip to content

Commit 2944b08

Browse files
authored
(POP 2812) Different scheduler for search (#1606)
add different scheduler for search
1 parent 5eeea8f commit 2944b08

File tree

3 files changed

+125
-14
lines changed

3 files changed

+125
-14
lines changed

iris-mpc-cpu/src/execution/hawk_main/intra_batch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub async fn intra_batch_is_match(
2929
assert_eq!(n_requests, search_queries[RIGHT].len());
3030
let n_rotations = search_queries[LEFT].first().map(|r| r.len()).unwrap_or(1);
3131

32-
let batches = Schedule::new(n_sessions, n_requests, n_rotations).batches();
32+
let batches = Schedule::new(n_sessions, n_requests, n_rotations).intra_match_batches();
3333

3434
let (tx, rx) = unbounded_channel::<IsMatch>();
3535

iris-mpc-cpu/src/execution/hawk_main/scheduler.rs

Lines changed: 123 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ impl Schedule {
5454

5555
/// Enumerate all combinations of eye sides, requests, and rotations.
5656
/// Distribute the tasks over a number of sessions.
57-
pub fn batches(&self) -> Vec<Batch> {
57+
/// Note: Should be used exclusively for intra_match_is_batch
58+
/// as it is optimized for its logic
59+
pub fn intra_match_batches(&self) -> Vec<Batch> {
5860
let n_tasks = self.n_requests * self.n_rotations;
5961
let batch_size = n_tasks / self.n_sessions;
6062
let rest_size = n_tasks % self.n_sessions;
@@ -86,6 +88,45 @@ impl Schedule {
8688
.collect_vec()
8789
}
8890

91+
/// Enumerate all combinations of eye sides, requests, and rotations.
92+
/// Distribute the tasks over a number of sessions.
93+
/// This method is search-aware and weighs central rotations as higher workloads than non-central ones
94+
pub fn search_batches(&self) -> Vec<Batch> {
95+
let n_tasks = self.n_requests * self.n_rotations;
96+
let batch_size = n_tasks / self.n_sessions;
97+
let rest_size = n_tasks % self.n_sessions;
98+
99+
(0..N_EYES)
100+
.flat_map(|i_eye| {
101+
// Iterate requests first and rotations second (contrast with intra_match_batches)
102+
// This ensures that heavier center tasks are grouped with their (lighter) rotations
103+
// This order also opens up the possibility of better memory access patterns, assuming rotations
104+
// of a fixed iris behave similarly
105+
let mut task_iter = (0..self.n_requests).flat_map(move |i_request| {
106+
(0..self.n_rotations).map(move |i_rotation| Task {
107+
i_eye,
108+
i_request,
109+
i_rotation,
110+
is_central: (i_rotation == self.n_rotations / 2),
111+
})
112+
});
113+
114+
(0..self.n_sessions).map(move |i_session| {
115+
// Some sessions get one more task if n_sessions does not divide n_tasks.
116+
let one_more = (i_session < rest_size) as usize;
117+
118+
let tasks = task_iter.by_ref().take(batch_size + one_more).collect_vec();
119+
120+
Batch {
121+
i_eye,
122+
i_session,
123+
tasks,
124+
}
125+
})
126+
})
127+
.collect_vec()
128+
}
129+
89130
pub fn organize_results<T, ROT: Rotations>(
90131
&self,
91132
mut results: HashMap<TaskId, T>,
@@ -154,25 +195,25 @@ mod test {
154195
use iris_mpc_common::ROTATIONS;
155196

156197
#[test]
157-
fn test_schedule() {
198+
fn test_intra_match_schedule() {
158199
for n_rotations in [1, ROTATIONS] {
159-
test_schedule_impl(1, 0, n_rotations);
160-
test_schedule_impl(1, 1, n_rotations);
161-
test_schedule_impl(1, 2, n_rotations);
162-
test_schedule_impl(10, 1, n_rotations);
163-
test_schedule_impl(1, 10, n_rotations);
164-
test_schedule_impl(7, 10, n_rotations);
165-
test_schedule_impl(10, 30, n_rotations);
166-
test_schedule_impl(10, 97, n_rotations);
200+
test_intra_match_schedule_impl(1, 0, n_rotations);
201+
test_intra_match_schedule_impl(1, 1, n_rotations);
202+
test_intra_match_schedule_impl(1, 2, n_rotations);
203+
test_intra_match_schedule_impl(10, 1, n_rotations);
204+
test_intra_match_schedule_impl(1, 10, n_rotations);
205+
test_intra_match_schedule_impl(7, 10, n_rotations);
206+
test_intra_match_schedule_impl(10, 30, n_rotations);
207+
test_intra_match_schedule_impl(10, 97, n_rotations);
167208
}
168209
}
169210

170-
fn test_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) {
211+
fn test_intra_match_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) {
171212
let n_eyes = N_EYES;
172213
let n_batches = n_eyes * n_sessions;
173214
let n_tasks = n_eyes * n_requests * n_rotations;
174215

175-
let batches = Schedule::new(n_sessions, n_requests, n_rotations).batches();
216+
let batches = Schedule::new(n_sessions, n_requests, n_rotations).intra_match_batches();
176217
assert_eq!(batches.len(), n_batches);
177218

178219
let count_tasks: usize = batches.iter().map(|b| b.tasks.len()).sum();
@@ -203,6 +244,76 @@ mod test {
203244
assert_eq!(unique_tasks, n_tasks);
204245
}
205246

247+
#[test]
248+
fn test_search_schedule() {
249+
for n_rotations in [1, ROTATIONS] {
250+
test_search_schedule_impl(1, 0, n_rotations);
251+
test_search_schedule_impl(1, 1, n_rotations);
252+
test_search_schedule_impl(1, 2, n_rotations);
253+
test_search_schedule_impl(10, 1, n_rotations);
254+
test_search_schedule_impl(1, 10, n_rotations);
255+
test_search_schedule_impl(7, 10, n_rotations);
256+
test_search_schedule_impl(10, 30, n_rotations);
257+
test_search_schedule_impl(10, 97, n_rotations);
258+
}
259+
}
260+
261+
fn test_search_schedule_impl(n_sessions: usize, n_requests: usize, n_rotations: usize) {
262+
let n_eyes = N_EYES;
263+
let n_batches = n_eyes * n_sessions;
264+
let n_tasks = n_eyes * n_requests * n_rotations;
265+
266+
let batches = Schedule::new(n_sessions, n_requests, n_rotations).search_batches();
267+
assert_eq!(batches.len(), n_batches);
268+
269+
let count_tasks: usize = batches.iter().map(|b| b.tasks.len()).sum();
270+
assert_eq!(count_tasks, n_tasks);
271+
272+
let unique_sessions = batches
273+
.iter()
274+
.map(|b| (b.i_eye, b.i_session))
275+
.unique()
276+
.count();
277+
assert_eq!(unique_sessions, n_batches);
278+
279+
let unique_tasks = batches
280+
.iter()
281+
.flat_map(|b| {
282+
assert!(b.i_eye < n_eyes);
283+
assert!(b.i_session < n_sessions);
284+
285+
b.tasks.iter().map(|t| {
286+
assert!(t.i_request < n_requests);
287+
assert!(t.i_rotation < n_rotations);
288+
289+
(b.i_eye, t.i_request, t.i_rotation)
290+
})
291+
})
292+
.unique()
293+
.count();
294+
assert_eq!(unique_tasks, n_tasks);
295+
296+
// Check central-rotation load balance
297+
let minmax = batches
298+
.iter()
299+
.map(|batch| {
300+
batch
301+
.tasks
302+
.iter()
303+
.map(|task| task.is_central as usize)
304+
.sum::<usize>()
305+
})
306+
.minmax();
307+
let dif = match minmax {
308+
itertools::MinMaxResult::NoElements => 0,
309+
itertools::MinMaxResult::OneElement(_) => 0,
310+
itertools::MinMaxResult::MinMax(min, max) => max - min,
311+
};
312+
// The difference might be 2 in contrived cases, but it's
313+
// not worth addressing
314+
assert!(dif <= 1);
315+
}
316+
206317
#[test]
207318
fn test_range_forward_backward() {
208319
assert!(range_forward_backward(0).collect_vec().is_empty());

iris-mpc-cpu/src/execution/hawk_main/search.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ where
6666

6767
let schedule = Schedule::new(n_sessions, n_requests, ROT::N_ROTATIONS);
6868

69-
parallelize(schedule.batches().into_iter().map(per_session)).await?;
69+
parallelize(schedule.search_batches().into_iter().map(per_session)).await?;
7070

7171
let results = schedule.organize_results(collect_results(rx).await?)?;
7272

0 commit comments

Comments
 (0)