Skip to content

Commit ba9d5fc

Browse files
reuse events (#861)
* reuse events
* assign context
* dbg
* init events once
* trigger image build
* deploy test image to stage

---------

Co-authored-by: Ertugrul Aypek <ertugrul.aypek@toolsforhumanity.com>
1 parent 6d5b2f2 commit ba9d5fc

File tree

4 files changed: +28 -35 lines changed

.github/workflows/temp-branch-build-and-push.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ name: Branch - Build and push docker image
33
on:
44
push:
55
branches:
6-
- "ps/host-mem-alloc"
6+
- "ps/reuse-events"
77

88
concurrency:
99
group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'

deploy/stage/common-values-iris-mpc.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
image: "ghcr.io/worldcoin/iris-mpc:6b358589c25ef528f58ba02980103670b037a614"
1+
image: "ghcr.io/worldcoin/iris-mpc:v0.13.6"
22

33
environment: stage
44
replicaCount: 1

iris-mpc-gpu/src/helpers/device_manager.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -103,8 +103,9 @@ impl DeviceManager {
103103
}
104104

105105
pub fn destroy_events(&self, events: Vec<CUevent>) {
106-
for event in events {
107-
unsafe { event::destroy(event).unwrap() };
106+
for (device_idx, event) in events.iter().enumerate() {
107+
self.device(device_idx).bind_to_thread().unwrap();
108+
unsafe { event::destroy(*event).unwrap() };
108109
}
109110
}
110111

iris-mpc-gpu/src/server/actor.rs

Lines changed: 23 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -107,6 +107,9 @@ pub struct ServerActor {
107107
disable_persistence: bool,
108108
code_chunk_buffers: Vec<DBChunkBuffers>,
109109
mask_chunk_buffers: Vec<DBChunkBuffers>,
110+
dot_events: Vec<Vec<CUevent>>,
111+
exchange_events: Vec<Vec<CUevent>>,
112+
phase2_events: Vec<Vec<CUevent>>,
110113
}
111114

112115
const NON_MATCH_ID: u32 = u32::MAX;
@@ -330,6 +333,11 @@ impl ServerActor {
330333
let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];
331334
let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];
332335

336+
// Create all needed events
337+
let dot_events = vec![device_manager.create_events(); 2];
338+
let exchange_events = vec![device_manager.create_events(); 2];
339+
let phase2_events = vec![device_manager.create_events(); 2];
340+
333341
for dev in device_manager.devices() {
334342
dev.synchronize().unwrap();
335343
}
@@ -367,6 +375,9 @@ impl ServerActor {
367375
disable_persistence,
368376
code_chunk_buffers,
369377
mask_chunk_buffers,
378+
dot_events,
379+
exchange_events,
380+
phase2_events,
370381
})
371382
}
372383

@@ -1126,14 +1137,6 @@ impl ServerActor {
11261137
tracing::info!(party_id = self.party_id, "Finished batch deduplication");
11271138
// ---- END BATCH DEDUP ----
11281139

1129-
// Create new initial events
1130-
let mut current_dot_event = self.device_manager.create_events();
1131-
let mut next_dot_event = self.device_manager.create_events();
1132-
let mut current_exchange_event = self.device_manager.create_events();
1133-
let mut next_exchange_event = self.device_manager.create_events();
1134-
let mut current_phase2_event = self.device_manager.create_events();
1135-
let mut next_phase2_event = self.device_manager.create_events();
1136-
11371140
let chunk_sizes = |chunk_idx: usize| {
11381141
self.current_db_sizes
11391142
.iter()
@@ -1195,11 +1198,11 @@ impl ServerActor {
11951198
// First stream doesn't need to wait
11961199
if db_chunk_idx == 0 {
11971200
self.device_manager
1198-
.record_event(request_streams, &current_dot_event);
1201+
.record_event(request_streams, &self.dot_events[db_chunk_idx % 2]);
11991202
self.device_manager
1200-
.record_event(request_streams, &current_exchange_event);
1203+
.record_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);
12011204
self.device_manager
1202-
.record_event(request_streams, &current_phase2_event);
1205+
.record_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);
12031206
}
12041207

12051208
// Prefetch next chunk
@@ -1229,7 +1232,7 @@ impl ServerActor {
12291232
);
12301233

12311234
self.device_manager
1232-
.await_event(request_streams, &current_dot_event);
1235+
.await_event(request_streams, &self.dot_events[db_chunk_idx % 2]);
12331236

12341237
// ---- START PHASE 1 ----
12351238
record_stream_time!(&self.device_manager, batch_streams, events, "db_dot", {
@@ -1247,7 +1250,7 @@ impl ServerActor {
12471250

12481251
// wait for the exchange result buffers to be ready
12491252
self.device_manager
1250-
.await_event(request_streams, &current_exchange_event);
1253+
.await_event(request_streams, &self.exchange_events[db_chunk_idx % 2]);
12511254

12521255
record_stream_time!(
12531256
&self.device_manager,
@@ -1268,7 +1271,7 @@ impl ServerActor {
12681271
);
12691272

12701273
self.device_manager
1271-
.record_event(request_streams, &next_dot_event);
1274+
.record_event(request_streams, &self.dot_events[(db_chunk_idx + 1) % 2]);
12721275

12731276
record_stream_time!(
12741277
&self.device_manager,
@@ -1286,7 +1289,7 @@ impl ServerActor {
12861289
// ---- END PHASE 1 ----
12871290

12881291
self.device_manager
1289-
.await_event(request_streams, &current_phase2_event);
1292+
.await_event(request_streams, &self.phase2_events[db_chunk_idx % 2]);
12901293

12911294
// ---- START PHASE 2 ----
12921295
let max_chunk_size = dot_chunk_size.iter().max().copied().unwrap();
@@ -1318,8 +1321,10 @@ impl ServerActor {
13181321
// we can now record the exchange event since the phase 2 is no longer using the
13191322
// code_dots/mask_dots which are just reinterpretations of the exchange result
13201323
// buffers
1321-
self.device_manager
1322-
.record_event(request_streams, &next_exchange_event);
1324+
self.device_manager.record_event(
1325+
request_streams,
1326+
&self.exchange_events[(db_chunk_idx + 1) % 2],
1327+
);
13231328

13241329
let res = self.phase2.take_result_buffer();
13251330
record_stream_time!(&self.device_manager, request_streams, events, "db_open", {
@@ -1340,23 +1345,10 @@ impl ServerActor {
13401345
});
13411346
}
13421347
self.device_manager
1343-
.record_event(request_streams, &next_phase2_event);
1348+
.record_event(request_streams, &self.phase2_events[(db_chunk_idx + 1) % 2]);
13441349

13451350
// ---- END PHASE 2 ----
13461351

1347-
// Destroy events
1348-
self.device_manager.destroy_events(current_dot_event);
1349-
self.device_manager.destroy_events(current_exchange_event);
1350-
self.device_manager.destroy_events(current_phase2_event);
1351-
1352-
// Update events for synchronization
1353-
current_dot_event = next_dot_event;
1354-
current_exchange_event = next_exchange_event;
1355-
current_phase2_event = next_phase2_event;
1356-
next_dot_event = self.device_manager.create_events();
1357-
next_exchange_event = self.device_manager.create_events();
1358-
next_phase2_event = self.device_manager.create_events();
1359-
13601352
// Increment chunk index
13611353
db_chunk_idx += 1;
13621354

0 commit comments

Comments
 (0)