@@ -54,7 +54,9 @@ impl Schedule {
54
54
55
55
/// Enumerate all combinations of eye sides, requests, and rotations.
56
56
/// Distribute the tasks over a number of sessions.
57
- pub fn batches ( & self ) -> Vec < Batch > {
57
+ /// Note: Should be used exclusively for intra_match_is_batch
58
+ /// as it is optimized for its logic
59
+ pub fn intra_match_batches ( & self ) -> Vec < Batch > {
58
60
let n_tasks = self . n_requests * self . n_rotations ;
59
61
let batch_size = n_tasks / self . n_sessions ;
60
62
let rest_size = n_tasks % self . n_sessions ;
@@ -86,6 +88,45 @@ impl Schedule {
86
88
. collect_vec ( )
87
89
}
88
90
91
+ /// Enumerate all combinations of eye sides, requests, and rotations.
92
+ /// Distribute the tasks over a number of sessions.
93
+ /// This method is search-aware and weighs central rotations as higher workloads than non-central ones
94
+ pub fn search_batches ( & self ) -> Vec < Batch > {
95
+ let n_tasks = self . n_requests * self . n_rotations ;
96
+ let batch_size = n_tasks / self . n_sessions ;
97
+ let rest_size = n_tasks % self . n_sessions ;
98
+
99
+ ( 0 ..N_EYES )
100
+ . flat_map ( |i_eye| {
101
+ // Iterate requests first and rotations second (contrast with intra_match_batches)
102
+ // This ensures that heavier center tasks are grouped with their (lighter) rotations
103
+ // This order also opens up the possibility of better memory access patterns, assuming rotations
104
+ // of a fixed iris behave similarly
105
+ let mut task_iter = ( 0 ..self . n_requests ) . flat_map ( move |i_request| {
106
+ ( 0 ..self . n_rotations ) . map ( move |i_rotation| Task {
107
+ i_eye,
108
+ i_request,
109
+ i_rotation,
110
+ is_central : ( i_rotation == self . n_rotations / 2 ) ,
111
+ } )
112
+ } ) ;
113
+
114
+ ( 0 ..self . n_sessions ) . map ( move |i_session| {
115
+ // Some sessions get one more task if n_sessions does not divide n_tasks.
116
+ let one_more = ( i_session < rest_size) as usize ;
117
+
118
+ let tasks = task_iter. by_ref ( ) . take ( batch_size + one_more) . collect_vec ( ) ;
119
+
120
+ Batch {
121
+ i_eye,
122
+ i_session,
123
+ tasks,
124
+ }
125
+ } )
126
+ } )
127
+ . collect_vec ( )
128
+ }
129
+
89
130
pub fn organize_results < T , ROT : Rotations > (
90
131
& self ,
91
132
mut results : HashMap < TaskId , T > ,
@@ -154,25 +195,25 @@ mod test {
154
195
use iris_mpc_common:: ROTATIONS ;
155
196
156
197
#[ test]
157
- fn test_schedule ( ) {
198
+ fn test_intra_match_schedule ( ) {
158
199
for n_rotations in [ 1 , ROTATIONS ] {
159
- test_schedule_impl ( 1 , 0 , n_rotations) ;
160
- test_schedule_impl ( 1 , 1 , n_rotations) ;
161
- test_schedule_impl ( 1 , 2 , n_rotations) ;
162
- test_schedule_impl ( 10 , 1 , n_rotations) ;
163
- test_schedule_impl ( 1 , 10 , n_rotations) ;
164
- test_schedule_impl ( 7 , 10 , n_rotations) ;
165
- test_schedule_impl ( 10 , 30 , n_rotations) ;
166
- test_schedule_impl ( 10 , 97 , n_rotations) ;
200
+ test_intra_match_schedule_impl ( 1 , 0 , n_rotations) ;
201
+ test_intra_match_schedule_impl ( 1 , 1 , n_rotations) ;
202
+ test_intra_match_schedule_impl ( 1 , 2 , n_rotations) ;
203
+ test_intra_match_schedule_impl ( 10 , 1 , n_rotations) ;
204
+ test_intra_match_schedule_impl ( 1 , 10 , n_rotations) ;
205
+ test_intra_match_schedule_impl ( 7 , 10 , n_rotations) ;
206
+ test_intra_match_schedule_impl ( 10 , 30 , n_rotations) ;
207
+ test_intra_match_schedule_impl ( 10 , 97 , n_rotations) ;
167
208
}
168
209
}
169
210
170
- fn test_schedule_impl ( n_sessions : usize , n_requests : usize , n_rotations : usize ) {
211
+ fn test_intra_match_schedule_impl ( n_sessions : usize , n_requests : usize , n_rotations : usize ) {
171
212
let n_eyes = N_EYES ;
172
213
let n_batches = n_eyes * n_sessions;
173
214
let n_tasks = n_eyes * n_requests * n_rotations;
174
215
175
- let batches = Schedule :: new ( n_sessions, n_requests, n_rotations) . batches ( ) ;
216
+ let batches = Schedule :: new ( n_sessions, n_requests, n_rotations) . intra_match_batches ( ) ;
176
217
assert_eq ! ( batches. len( ) , n_batches) ;
177
218
178
219
let count_tasks: usize = batches. iter ( ) . map ( |b| b. tasks . len ( ) ) . sum ( ) ;
@@ -203,6 +244,76 @@ mod test {
203
244
assert_eq ! ( unique_tasks, n_tasks) ;
204
245
}
205
246
247
+ #[ test]
248
+ fn test_search_schedule ( ) {
249
+ for n_rotations in [ 1 , ROTATIONS ] {
250
+ test_search_schedule_impl ( 1 , 0 , n_rotations) ;
251
+ test_search_schedule_impl ( 1 , 1 , n_rotations) ;
252
+ test_search_schedule_impl ( 1 , 2 , n_rotations) ;
253
+ test_search_schedule_impl ( 10 , 1 , n_rotations) ;
254
+ test_search_schedule_impl ( 1 , 10 , n_rotations) ;
255
+ test_search_schedule_impl ( 7 , 10 , n_rotations) ;
256
+ test_search_schedule_impl ( 10 , 30 , n_rotations) ;
257
+ test_search_schedule_impl ( 10 , 97 , n_rotations) ;
258
+ }
259
+ }
260
+
261
+ fn test_search_schedule_impl ( n_sessions : usize , n_requests : usize , n_rotations : usize ) {
262
+ let n_eyes = N_EYES ;
263
+ let n_batches = n_eyes * n_sessions;
264
+ let n_tasks = n_eyes * n_requests * n_rotations;
265
+
266
+ let batches = Schedule :: new ( n_sessions, n_requests, n_rotations) . search_batches ( ) ;
267
+ assert_eq ! ( batches. len( ) , n_batches) ;
268
+
269
+ let count_tasks: usize = batches. iter ( ) . map ( |b| b. tasks . len ( ) ) . sum ( ) ;
270
+ assert_eq ! ( count_tasks, n_tasks) ;
271
+
272
+ let unique_sessions = batches
273
+ . iter ( )
274
+ . map ( |b| ( b. i_eye , b. i_session ) )
275
+ . unique ( )
276
+ . count ( ) ;
277
+ assert_eq ! ( unique_sessions, n_batches) ;
278
+
279
+ let unique_tasks = batches
280
+ . iter ( )
281
+ . flat_map ( |b| {
282
+ assert ! ( b. i_eye < n_eyes) ;
283
+ assert ! ( b. i_session < n_sessions) ;
284
+
285
+ b. tasks . iter ( ) . map ( |t| {
286
+ assert ! ( t. i_request < n_requests) ;
287
+ assert ! ( t. i_rotation < n_rotations) ;
288
+
289
+ ( b. i_eye , t. i_request , t. i_rotation )
290
+ } )
291
+ } )
292
+ . unique ( )
293
+ . count ( ) ;
294
+ assert_eq ! ( unique_tasks, n_tasks) ;
295
+
296
+ // Check central-rotation load balance
297
+ let minmax = batches
298
+ . iter ( )
299
+ . map ( |batch| {
300
+ batch
301
+ . tasks
302
+ . iter ( )
303
+ . map ( |task| task. is_central as usize )
304
+ . sum :: < usize > ( )
305
+ } )
306
+ . minmax ( ) ;
307
+ let dif = match minmax {
308
+ itertools:: MinMaxResult :: NoElements => 0 ,
309
+ itertools:: MinMaxResult :: OneElement ( _) => 0 ,
310
+ itertools:: MinMaxResult :: MinMax ( min, max) => max - min,
311
+ } ;
312
+ // The difference might be 2 in contrived cases, but it's
313
+ // not worth addressing
314
+ assert ! ( dif <= 1 ) ;
315
+ }
316
+
206
317
#[ test]
207
318
fn test_range_forward_backward ( ) {
208
319
assert ! ( range_forward_backward( 0 ) . collect_vec( ) . is_empty( ) ) ;
0 commit comments