4 changes: 4 additions & 0 deletions iris-mpc-common/src/helpers/shutdown_handler.rs
@@ -36,6 +36,10 @@ impl ShutdownHandler {
self.ct.cancelled().await
}

pub fn get_cancellation_token(&self) -> CancellationToken {
self.ct.clone()
}

pub async fn register_signal_handler(&self) {
let ct = self.ct.clone();
tokio::spawn(async move {
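The new accessor lets callers hand the shutdown token to subsystems that need to observe cancellation, for example the Hawk actor below. A minimal wiring sketch, assuming an already-constructed `ShutdownHandler` value named `shutdown_handler` (its setup is outside this diff):

    // Sketch only: thread the shutdown token into the networking stack via the
    // new accessor. `shutdown_handler` and `args` are assumed to exist here.
    let ct = shutdown_handler.get_cancellation_token();
    let hawk_actor = HawkActor::from_cli(&args, ct).await?;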
12 changes: 8 additions & 4 deletions iris-mpc-cpu/src/execution/hawk_main.rs
@@ -67,6 +67,7 @@ use tokio::{
join,
sync::{mpsc, oneshot, RwLock, RwLockWriteGuard},
};
use tokio_util::sync::CancellationToken;

pub type GraphStore = graph_store::GraphPg<Aby3Store>;
pub type GraphTx<'a> = graph_store::GraphTx<'a, Aby3Store>;
@@ -292,9 +293,10 @@ impl HawkInsertPlan {
}

impl HawkActor {
pub async fn from_cli(args: &HawkArgs) -> Result<Self> {
pub async fn from_cli(args: &HawkArgs, ct: CancellationToken) -> Result<Self> {
Self::from_cli_with_graph_and_store(
args,
ct,
[(); 2].map(|_| GraphMem::new()),
[(); 2].map(|_| Aby3Store::new_storage(None)),
)
@@ -303,6 +305,7 @@ impl HawkActor {

pub async fn from_cli_with_graph_and_store(
args: &HawkArgs,
ct: CancellationToken,
graph: BothEyes<GraphMem<Aby3VectorRef>>,
iris_store: BothEyes<Aby3SharedIrises>,
) -> Result<Self> {
@@ -326,7 +329,8 @@
let my_index = args.party_index;

let networking =
build_network_handle(args, &identities, SessionGroups::N_SESSIONS_PER_REQUEST).await?;
build_network_handle(args, ct, &identities, SessionGroups::N_SESSIONS_PER_REQUEST)
.await?;
let graph_store = graph.map(GraphMem::to_arc);
let iris_store = iris_store.map(SharedIrises::to_arc);
let workers_handle = [LEFT, RIGHT]
@@ -1653,7 +1657,7 @@ impl HawkHandle {

pub async fn hawk_main(args: HawkArgs) -> Result<HawkHandle> {
println!("🦅 Starting Hawk node {}", args.party_index);
let hawk_actor = HawkActor::from_cli(&args).await?;
let hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
HawkHandle::new(hawk_actor).await
}

@@ -2049,7 +2053,7 @@ mod tests_db {
disable_persistence: false,
tls: None,
};
let mut hawk_actor = HawkActor::from_cli(&args).await?;
let mut hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
let (_, graph_loader) = hawk_actor.as_iris_loader().await;
graph_loader.load_graph_store(&graph_store, 2).await?;

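Note that `hawk_main` still creates a fresh token internally, so a binary that wants the TCP layer to react to its own shutdown signal would need to pass one in. A hypothetical wrapper, not part of this diff, could look like:

    // Hypothetical helper (assumed name): lets the caller supply the token that
    // the TCP/TLS layer watches for shutdown, instead of a fresh one.
    pub async fn hawk_main_with_ct(args: HawkArgs, ct: CancellationToken) -> Result<HawkHandle> {
        println!("🦅 Starting Hawk node {}", args.party_index);
        let hawk_actor = HawkActor::from_cli(&args, ct).await?;
        HawkHandle::new(hawk_actor).await
    }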
3 changes: 2 additions & 1 deletion iris-mpc-cpu/src/execution/hawk_main/test_utils.rs
@@ -11,6 +11,7 @@ use itertools::Itertools;
use rand::SeedableRng;
use std::{sync::Arc, time::Duration};
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;

use crate::{
execution::local::get_free_local_addresses,
@@ -42,7 +43,7 @@ pub async fn setup_hawk_actors() -> Result<Vec<HawkActor>> {
// Make the test async.
sleep(Duration::from_millis(index as u64)).await;

HawkActor::from_cli(&args).await
HawkActor::from_cli(&args, CancellationToken::new()).await
}
};

173 changes: 144 additions & 29 deletions iris-mpc-cpu/src/network/tcp/handle.rs
@@ -13,15 +13,23 @@ use async_trait::async_trait;
use bytes::BytesMut;
use eyre::Result;
use iris_mpc_common::fast_metrics::FastHistogram;
use std::io;
use std::{collections::HashMap, time::Instant};
use std::{
collections::HashMap,
io,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Instant,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
sync::{
mpsc::{self, error::TryRecvError, UnboundedReceiver, UnboundedSender},
oneshot,
oneshot, Mutex,
},
};
use tokio_util::sync::CancellationToken;

const BUFFER_CAPACITY: usize = 32 * 1024;
const READ_BUF_SIZE: usize = 2 * 1024 * 1024;
@@ -69,21 +77,26 @@ impl<T: NetworkConnection + 'static> TcpNetworkHandle<T> {
reconnector: Reconnector<T>,
connections: PeerConnections<T>,
config: TcpConfig,
ct: CancellationToken,
) -> Self {
let peers = connections.keys().cloned().collect();
let mut ch_map = HashMap::new();
let conn_state = ConnectionState::new();
for (peer_id, connections) in connections {
let mut m = HashMap::new();
for (stream_id, connection) in connections {
let rc = reconnector.clone();
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
m.insert(stream_id, cmd_tx);

let ct2 = ct.clone();
tokio::spawn(manage_connection(
connection,
rc,
config.get_sessions_for_stream(&stream_id),
cmd_rx,
ct2,
conn_state.clone(),
));
}
ch_map.insert(peer_id.clone(), m);
@@ -220,6 +233,8 @@ async fn manage_connection<T: NetworkConnection>(
reconnector: Reconnector<T>,
num_sessions: usize,
mut cmd_ch: UnboundedReceiver<Cmd>,
ct: CancellationToken,
conn_state: ConnectionState,
) {
let Connection {
peer,
@@ -257,35 +272,55 @@
// when new sessions are requested, also tear down and stand up the forwarders.
loop {
let r_writer = writer.as_mut().expect("writer should be Some");
let r_outbound = &mut outbound_rx;
let outbound_task = async move {
let r = handle_outbound_traffic(r_writer, r_outbound, num_sessions).await;
tracing::warn!("handle_outbound_traffic exited: {r:?}");
};
let outbound_task = handle_outbound_traffic(r_writer, &mut outbound_rx, num_sessions);

let r_reader = reader.as_mut().expect("reader should be Some");
let r_inbound = &inbound_forwarder;
let inbound_task = async move {
let r = handle_inbound_traffic(r_reader, r_inbound).await;
tracing::warn!("handle_inbound_traffic exited: {r:?}");
};
let inbound_task = handle_inbound_traffic(r_reader, &inbound_forwarder);

enum Evt {
Cmd(Cmd),
Disconnected,
Shutdown,
}
let event = tokio::select! {
maybe_cmd = cmd_ch.recv() => {
match maybe_cmd {
Some(cmd) => Evt::Cmd(cmd),
None => {
tracing::info!("cmd channel closed");
// The sender side was dropped, meaning the networking stack was shut down before a command was received.
if conn_state.exited() {
tracing::info!("cmd channel closed");
};
ct.cancel();
return;
}
}
}
_ = inbound_task => Evt::Disconnected,
_ = outbound_task => Evt::Disconnected,
r = inbound_task => {
if ct.is_cancelled() {
Evt::Shutdown
} else if let Err(e) = r {
if conn_state.incr_reconnect().await {
tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
}
Evt::Disconnected
} else {
unreachable!();
}
},
r = outbound_task => {
if ct.is_cancelled() {
Evt::Shutdown
} else if let Err(e) = r {
if conn_state.incr_reconnect().await {
tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
}
Evt::Disconnected
} else {
unreachable!();
}
},
_ = ct.cancelled() => Evt::Shutdown,
};

// update the Arcs depending on the event. wait for reconnect if needed.
@@ -314,21 +349,35 @@
)
.await
{
tracing::error!("reconnect failed: {e:?}");
tracing::debug!("reconnect failed: {e:?}");
return;
};
rsp.send(Ok(())).unwrap();
}
},
Evt::Disconnected => {
tracing::info!("reconnecting to {:?}: {:?}", peer, stream_id);
tracing::debug!("reconnecting to {:?}: {:?}", peer, stream_id);
if let Err(e) =
reconnect_and_replace(&reconnector, &peer, stream_id, &mut reader, &mut writer)
.await
{
tracing::error!("reconnect failed: {e:?}");
if conn_state.exited() {
if ct.is_cancelled() {
tracing::info!("shutting down TCP/TLS networking stack");
} else {
tracing::error!("reconnect failed: {e:?}");
}
}
return;
};
} else if conn_state.decr_reconnect().await {
tracing::info!("all connections re-established");
}
}
Evt::Shutdown => {
if conn_state.exited() {
tracing::info!("shutting down TCP/TLS networking stack");
}
return;
}
}
}
@@ -412,17 +461,11 @@ async fn handle_outbound_traffic<T: NetworkConnection>(
}
}

if let Err(e) = write_buf(stream, &mut buf).await {
tracing::error!(error=%e, "Failed to flush buffer on outbound_rx");
return Err(e);
}
write_buf(stream, &mut buf).await?
}

if !buf.is_empty() {
if let Err(e) = write_buf(stream, &mut buf).await {
tracing::error!(error=%e, "Failed to flush buffer when outbound_rx closed");
return Err(e);
}
write_buf(stream, &mut buf).await?
}
// the channel will not receive any more commands
tracing::debug!("outbound_rx closed");
@@ -493,7 +536,10 @@ async fn handle_inbound_traffic<T: NetworkConnection>(
}
}
Err(e) => {
tracing::error!("failed to deserialize message: {e}");
return Err(io::Error::new(
io::ErrorKind::Other,
format!("failed to deserialize message: {e}"),
));
}
};
} else {
@@ -515,3 +561,72 @@
buf.clear();
Ok(())
}

// State shared by all connections, used to reduce the number of log lines
// emitted upon loss of network connectivity.
#[derive(Clone)]
struct ConnectionState {
inner: Arc<ConnectionStateInner>,
}

impl ConnectionState {
fn new() -> Self {
Self {
inner: Arc::new(ConnectionStateInner::new()),
}
}

fn exited(&self) -> bool {
self.inner
.exited
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
}

async fn incr_reconnect(&self) -> bool {
self.inner.reconnecting.increment().await
}

async fn decr_reconnect(&self) -> bool {
self.inner.reconnecting.decrement().await
}
}

struct ConnectionStateInner {
reconnecting: Counter,
exited: AtomicBool,
}

impl ConnectionStateInner {
fn new() -> Self {
Self {
reconnecting: Counter::new(),
exited: AtomicBool::new(false),
}
}
}

struct Counter {
num: Mutex<usize>,
}

impl Counter {
fn new() -> Self {
Self { num: Mutex::new(0) }
}

// returns true if num was zero before incrementing
async fn increment(&self) -> bool {
let mut l = self.num.lock().await;
*l += 1;
*l == 1
}

// returns true if num was one before decrementing
async fn decrement(&self) -> bool {
let mut l = self.num.lock().await;
let r = *l == 1;
*l = l.saturating_sub(1);
r
}
}
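`Counter` only reports the 0 -> 1 and 1 -> 0 transitions, which is what keeps the reconnect logging to one line per outage across all connections. A small sketch of the intended semantics (assuming test-level access to the private type):

    // Only the first drop and the last recovery return true, so only those
    // transitions emit log lines.
    #[tokio::test]
    async fn counter_reports_only_edge_transitions() {
        let c = Counter::new();
        assert!(c.increment().await);   // 0 -> 1: first connection lost
        assert!(!c.increment().await);  // 1 -> 2: stays quiet
        assert!(!c.decrement().await);  // 2 -> 1: still reconnecting
        assert!(c.decrement().await);   // 1 -> 0: everything re-established
    }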