Skip to content

Commit 303d90e

Browse files
committed
Merge branch 'dev' into au/sw/graceful_shutdown3
2 parents 31a3780 + 9844546 commit 303d90e

File tree

14 files changed

+227
-99
lines changed

14 files changed

+227
-99
lines changed

deploy/dev/ampc-hnsw-0-dev/values-ampc-hnsw.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ env:
5555
value: "30"
5656

5757
- name: SMPC__PROCESSING_TIMEOUT_SECS
58-
value: "1240"
58+
value: "1800"
5959

6060
- name: SMPC__HAWK_REQUEST_PARALLELISM
6161
value: "32"

deploy/dev/ampc-hnsw-1-dev/values-ampc-hnsw.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ env:
5555
value: "30"
5656

5757
- name: SMPC__PROCESSING_TIMEOUT_SECS
58-
value: "1240"
58+
value: "1800"
5959

6060
- name: SMPC__HAWK_REQUEST_PARALLELISM
6161
value: "32"

deploy/dev/ampc-hnsw-2-dev/values-ampc-hnsw.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ env:
5555
value: "30"
5656

5757
- name: SMPC__PROCESSING_TIMEOUT_SECS
58-
value: "1240"
58+
value: "1800"
5959

6060
- name: SMPC__HAWK_REQUEST_PARALLELISM
6161
value: "32"

deploy/dev/common-values-ampc-hnsw.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
image: "ghcr.io/worldcoin/iris-mpc-cpu:d3a394f65b2f99a76111441373077b1c6dcf254f"
1+
image: "ghcr.io/worldcoin/iris-mpc-cpu:04dc776bc0a4c046e9ecbaad797e8605328cb121"
22

33
environment: dev
44
replicaCount: 1

iris-mpc-common/src/helpers/shutdown_handler.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ impl ShutdownHandler {
3636
self.ct.cancelled().await
3737
}
3838

39+
pub fn get_cancellation_token(&self) -> CancellationToken {
40+
self.ct.clone()
41+
}
42+
3943
pub async fn register_signal_handler(&self) {
4044
let ct = self.ct.clone();
4145
tokio::spawn(async move {

iris-mpc-cpu/src/execution/hawk_main.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ use tokio::{
6767
join,
6868
sync::{mpsc, oneshot, RwLock, RwLockWriteGuard},
6969
};
70+
use tokio_util::sync::CancellationToken;
7071

7172
pub type GraphStore = graph_store::GraphPg<Aby3Store>;
7273
pub type GraphTx<'a> = graph_store::GraphTx<'a, Aby3Store>;
@@ -292,9 +293,10 @@ impl HawkInsertPlan {
292293
}
293294

294295
impl HawkActor {
295-
pub async fn from_cli(args: &HawkArgs) -> Result<Self> {
296+
pub async fn from_cli(args: &HawkArgs, ct: CancellationToken) -> Result<Self> {
296297
Self::from_cli_with_graph_and_store(
297298
args,
299+
ct,
298300
[(); 2].map(|_| GraphMem::new()),
299301
[(); 2].map(|_| Aby3Store::new_storage(None)),
300302
)
@@ -303,6 +305,7 @@ impl HawkActor {
303305

304306
pub async fn from_cli_with_graph_and_store(
305307
args: &HawkArgs,
308+
ct: CancellationToken,
306309
graph: BothEyes<GraphMem<Aby3VectorRef>>,
307310
iris_store: BothEyes<Aby3SharedIrises>,
308311
) -> Result<Self> {
@@ -326,7 +329,8 @@ impl HawkActor {
326329
let my_index = args.party_index;
327330

328331
let networking =
329-
build_network_handle(args, &identities, SessionGroups::N_SESSIONS_PER_REQUEST).await?;
332+
build_network_handle(args, ct, &identities, SessionGroups::N_SESSIONS_PER_REQUEST)
333+
.await?;
330334
let graph_store = graph.map(GraphMem::to_arc);
331335
let iris_store = iris_store.map(SharedIrises::to_arc);
332336
let workers_handle = [LEFT, RIGHT]
@@ -1653,7 +1657,7 @@ impl HawkHandle {
16531657

16541658
pub async fn hawk_main(args: HawkArgs) -> Result<HawkHandle> {
16551659
println!("🦅 Starting Hawk node {}", args.party_index);
1656-
let hawk_actor = HawkActor::from_cli(&args).await?;
1660+
let hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
16571661
HawkHandle::new(hawk_actor).await
16581662
}
16591663

@@ -2049,7 +2053,7 @@ mod tests_db {
20492053
disable_persistence: false,
20502054
tls: None,
20512055
};
2052-
let mut hawk_actor = HawkActor::from_cli(&args).await?;
2056+
let mut hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
20532057
let (_, graph_loader) = hawk_actor.as_iris_loader().await;
20542058
graph_loader.load_graph_store(&graph_store, 2).await?;
20552059

iris-mpc-cpu/src/execution/hawk_main/test_utils.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use itertools::Itertools;
1111
use rand::SeedableRng;
1212
use std::{sync::Arc, time::Duration};
1313
use tokio::time::sleep;
14+
use tokio_util::sync::CancellationToken;
1415

1516
use crate::{
1617
execution::local::get_free_local_addresses,
@@ -42,7 +43,7 @@ pub async fn setup_hawk_actors() -> Result<Vec<HawkActor>> {
4243
// Make the test async.
4344
sleep(Duration::from_millis(index as u64)).await;
4445

45-
HawkActor::from_cli(&args).await
46+
HawkActor::from_cli(&args, CancellationToken::new()).await
4647
}
4748
};
4849

iris-mpc-cpu/src/network/tcp/handle.rs

Lines changed: 144 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,23 @@ use async_trait::async_trait;
1313
use bytes::BytesMut;
1414
use eyre::Result;
1515
use iris_mpc_common::fast_metrics::FastHistogram;
16-
use std::io;
17-
use std::{collections::HashMap, time::Instant};
16+
use std::{
17+
collections::HashMap,
18+
io,
19+
sync::{
20+
atomic::{AtomicBool, Ordering},
21+
Arc,
22+
},
23+
time::Instant,
24+
};
1825
use tokio::{
1926
io::{AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
2027
sync::{
2128
mpsc::{self, error::TryRecvError, UnboundedReceiver, UnboundedSender},
22-
oneshot,
29+
oneshot, Mutex,
2330
},
2431
};
32+
use tokio_util::sync::CancellationToken;
2533

2634
const BUFFER_CAPACITY: usize = 32 * 1024;
2735
const READ_BUF_SIZE: usize = 2 * 1024 * 1024;
@@ -69,21 +77,26 @@ impl<T: NetworkConnection + 'static> TcpNetworkHandle<T> {
6977
reconnector: Reconnector<T>,
7078
connections: PeerConnections<T>,
7179
config: TcpConfig,
80+
ct: CancellationToken,
7281
) -> Self {
7382
let peers = connections.keys().cloned().collect();
7483
let mut ch_map = HashMap::new();
84+
let conn_state = ConnectionState::new();
7585
for (peer_id, connections) in connections {
7686
let mut m = HashMap::new();
7787
for (stream_id, connection) in connections {
7888
let rc = reconnector.clone();
7989
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
8090
m.insert(stream_id, cmd_tx);
8191

92+
let ct2 = ct.clone();
8293
tokio::spawn(manage_connection(
8394
connection,
8495
rc,
8596
config.get_sessions_for_stream(&stream_id),
8697
cmd_rx,
98+
ct2,
99+
conn_state.clone(),
87100
));
88101
}
89102
ch_map.insert(peer_id.clone(), m);
@@ -220,6 +233,8 @@ async fn manage_connection<T: NetworkConnection>(
220233
reconnector: Reconnector<T>,
221234
num_sessions: usize,
222235
mut cmd_ch: UnboundedReceiver<Cmd>,
236+
ct: CancellationToken,
237+
conn_state: ConnectionState,
223238
) {
224239
let Connection {
225240
peer,
@@ -257,35 +272,55 @@ async fn manage_connection<T: NetworkConnection>(
257272
// when new sessions are requested, also tear down and stand up the forwarders.
258273
loop {
259274
let r_writer = writer.as_mut().expect("writer should be Some");
260-
let r_outbound = &mut outbound_rx;
261-
let outbound_task = async move {
262-
let r = handle_outbound_traffic(r_writer, r_outbound, num_sessions).await;
263-
tracing::warn!("handle_outbound_traffic exited: {r:?}");
264-
};
275+
let outbound_task = handle_outbound_traffic(r_writer, &mut outbound_rx, num_sessions);
265276

266277
let r_reader = reader.as_mut().expect("reader should be Some");
267-
let r_inbound = &inbound_forwarder;
268-
let inbound_task = async move {
269-
let r = handle_inbound_traffic(r_reader, r_inbound).await;
270-
tracing::warn!("handle_inbound_traffic exited: {r:?}");
271-
};
278+
let inbound_task = handle_inbound_traffic(r_reader, &inbound_forwarder);
272279

273280
enum Evt {
274281
Cmd(Cmd),
275282
Disconnected,
283+
Shutdown,
276284
}
277285
let event = tokio::select! {
278286
maybe_cmd = cmd_ch.recv() => {
279287
match maybe_cmd {
280288
Some(cmd) => Evt::Cmd(cmd),
281289
None => {
282-
tracing::info!("cmd channel closed");
290+
// basically means the networking stack was shut down before a command was received.
291+
if conn_state.exited() {
292+
tracing::info!("cmd channel closed");
293+
};
294+
ct.cancel();
283295
return;
284296
}
285297
}
286298
}
287-
_ = inbound_task => Evt::Disconnected,
288-
_ = outbound_task => Evt::Disconnected,
299+
r = inbound_task => {
300+
if ct.is_cancelled() {
301+
Evt::Shutdown
302+
} else if let Err(e) = r {
303+
if conn_state.incr_reconnect().await {
304+
tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
305+
}
306+
Evt::Disconnected
307+
} else {
308+
unreachable!();
309+
}
310+
},
311+
r = outbound_task => {
312+
if ct.is_cancelled() {
313+
Evt::Shutdown
314+
} else if let Err(e) = r {
315+
if conn_state.incr_reconnect().await {
316+
tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
317+
}
318+
Evt::Disconnected
319+
} else {
320+
unreachable!();
321+
}
322+
},
323+
_ = ct.cancelled() => Evt::Shutdown,
289324
};
290325

291326
// update the Arcs depending on the event. wait for reconnect if needed.
@@ -314,21 +349,35 @@ async fn manage_connection<T: NetworkConnection>(
314349
)
315350
.await
316351
{
317-
tracing::error!("reconnect failed: {e:?}");
352+
tracing::debug!("reconnect failed: {e:?}");
318353
return;
319354
};
320355
rsp.send(Ok(())).unwrap();
321356
}
322357
},
323358
Evt::Disconnected => {
324-
tracing::info!("reconnecting to {:?}: {:?}", peer, stream_id);
359+
tracing::debug!("reconnecting to {:?}: {:?}", peer, stream_id);
325360
if let Err(e) =
326361
reconnect_and_replace(&reconnector, &peer, stream_id, &mut reader, &mut writer)
327362
.await
328363
{
329-
tracing::error!("reconnect failed: {e:?}");
364+
if conn_state.exited() {
365+
if ct.is_cancelled() {
366+
tracing::info!("shutting down TCP/TLS networking stack");
367+
} else {
368+
tracing::error!("reconnect failed: {e:?}");
369+
}
370+
}
330371
return;
331-
};
372+
} else if conn_state.decr_reconnect().await {
373+
tracing::info!("all connections re-established");
374+
}
375+
}
376+
Evt::Shutdown => {
377+
if conn_state.exited() {
378+
tracing::info!("shutting down TCP/TLS networking stack");
379+
}
380+
return;
332381
}
333382
}
334383
}
@@ -412,17 +461,11 @@ async fn handle_outbound_traffic<T: NetworkConnection>(
412461
}
413462
}
414463

415-
if let Err(e) = write_buf(stream, &mut buf).await {
416-
tracing::error!(error=%e, "Failed to flush buffer on outbound_rx");
417-
return Err(e);
418-
}
464+
write_buf(stream, &mut buf).await?
419465
}
420466

421467
if !buf.is_empty() {
422-
if let Err(e) = write_buf(stream, &mut buf).await {
423-
tracing::error!(error=%e, "Failed to flush buffer when outbound_rx closed");
424-
return Err(e);
425-
}
468+
write_buf(stream, &mut buf).await?
426469
}
427470
// the channel will not receive any more commands
428471
tracing::debug!("outbound_rx closed");
@@ -493,7 +536,10 @@ async fn handle_inbound_traffic<T: NetworkConnection>(
493536
}
494537
}
495538
Err(e) => {
496-
tracing::error!("failed to deserialize message: {e}");
539+
return Err(io::Error::new(
540+
io::ErrorKind::Other,
541+
format!("failed to deserialize message: {e}"),
542+
));
497543
}
498544
};
499545
} else {
@@ -515,3 +561,72 @@ async fn write_buf<T: NetworkConnection>(
515561
buf.clear();
516562
Ok(())
517563
}
564+
565+
// state which is shared by all connections. used to reduce the number of logs
566+
// emitted upon loss of network connectivity.
567+
#[derive(Clone)]
568+
struct ConnectionState {
569+
inner: Arc<ConnectionStateInner>,
570+
}
571+
572+
impl ConnectionState {
573+
fn new() -> Self {
574+
Self {
575+
inner: Arc::new(ConnectionStateInner::new()),
576+
}
577+
}
578+
579+
fn exited(&self) -> bool {
580+
self.inner
581+
.exited
582+
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
583+
.is_ok()
584+
}
585+
586+
async fn incr_reconnect(&self) -> bool {
587+
self.inner.reconnecting.increment().await
588+
}
589+
590+
async fn decr_reconnect(&self) -> bool {
591+
self.inner.reconnecting.decrement().await
592+
}
593+
}
594+
595+
struct ConnectionStateInner {
596+
reconnecting: Counter,
597+
exited: AtomicBool,
598+
}
599+
600+
impl ConnectionStateInner {
601+
fn new() -> Self {
602+
Self {
603+
reconnecting: Counter::new(),
604+
exited: AtomicBool::new(false),
605+
}
606+
}
607+
}
608+
609+
struct Counter {
610+
num: Mutex<usize>,
611+
}
612+
613+
impl Counter {
614+
fn new() -> Self {
615+
Self { num: Mutex::new(0) }
616+
}
617+
618+
// returns true if num was zero
619+
async fn increment(&self) -> bool {
620+
let mut l = self.num.lock().await;
621+
*l += 1;
622+
*l == 1
623+
}
624+
625+
// returns true if num was one before decrementing
626+
async fn decrement(&self) -> bool {
627+
let mut l = self.num.lock().await;
628+
let r = *l == 1;
629+
*l = l.saturating_sub(1);
630+
r
631+
}
632+
}

0 commit comments

Comments
 (0)