Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deploy/dev/ampc-hnsw-0-dev/values-ampc-hnsw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ env:
value: "30"

- name: SMPC__PROCESSING_TIMEOUT_SECS
value: "1240"
value: "1800"

- name: SMPC__HAWK_REQUEST_PARALLELISM
value: "32"
Expand Down
2 changes: 1 addition & 1 deletion deploy/dev/ampc-hnsw-1-dev/values-ampc-hnsw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ env:
value: "30"

- name: SMPC__PROCESSING_TIMEOUT_SECS
value: "1240"
value: "1800"

- name: SMPC__HAWK_REQUEST_PARALLELISM
value: "32"
Expand Down
2 changes: 1 addition & 1 deletion deploy/dev/ampc-hnsw-2-dev/values-ampc-hnsw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ env:
value: "30"

- name: SMPC__PROCESSING_TIMEOUT_SECS
value: "1240"
value: "1800"

- name: SMPC__HAWK_REQUEST_PARALLELISM
value: "32"
Expand Down
2 changes: 1 addition & 1 deletion deploy/dev/common-values-ampc-hnsw.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
image: "ghcr.io/worldcoin/iris-mpc-cpu:d3a394f65b2f99a76111441373077b1c6dcf254f"
image: "ghcr.io/worldcoin/iris-mpc-cpu:04dc776bc0a4c046e9ecbaad797e8605328cb121"

environment: dev
replicaCount: 1
Expand Down
4 changes: 4 additions & 0 deletions iris-mpc-common/src/helpers/shutdown_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ impl ShutdownHandler {
self.ct.cancelled().await
}

pub fn get_cancellation_token(&self) -> CancellationToken {
self.ct.clone()
}

pub async fn register_signal_handler(&self) {
let ct = self.ct.clone();
tokio::spawn(async move {
Expand Down
12 changes: 8 additions & 4 deletions iris-mpc-cpu/src/execution/hawk_main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ use tokio::{
join,
sync::{mpsc, oneshot, RwLock, RwLockWriteGuard},
};
use tokio_util::sync::CancellationToken;

pub type GraphStore = graph_store::GraphPg<Aby3Store>;
pub type GraphTx<'a> = graph_store::GraphTx<'a, Aby3Store>;
Expand Down Expand Up @@ -292,9 +293,10 @@ impl HawkInsertPlan {
}

impl HawkActor {
pub async fn from_cli(args: &HawkArgs) -> Result<Self> {
pub async fn from_cli(args: &HawkArgs, ct: CancellationToken) -> Result<Self> {
Self::from_cli_with_graph_and_store(
args,
ct,
[(); 2].map(|_| GraphMem::new()),
[(); 2].map(|_| Aby3Store::new_storage(None)),
)
Expand All @@ -303,6 +305,7 @@ impl HawkActor {

pub async fn from_cli_with_graph_and_store(
args: &HawkArgs,
ct: CancellationToken,
graph: BothEyes<GraphMem<Aby3VectorRef>>,
iris_store: BothEyes<Aby3SharedIrises>,
) -> Result<Self> {
Expand All @@ -326,7 +329,8 @@ impl HawkActor {
let my_index = args.party_index;

let networking =
build_network_handle(args, &identities, SessionGroups::N_SESSIONS_PER_REQUEST).await?;
build_network_handle(args, ct, &identities, SessionGroups::N_SESSIONS_PER_REQUEST)
.await?;
let graph_store = graph.map(GraphMem::to_arc);
let iris_store = iris_store.map(SharedIrises::to_arc);
let workers_handle = [LEFT, RIGHT]
Expand Down Expand Up @@ -1653,7 +1657,7 @@ impl HawkHandle {

pub async fn hawk_main(args: HawkArgs) -> Result<HawkHandle> {
println!("🦅 Starting Hawk node {}", args.party_index);
let hawk_actor = HawkActor::from_cli(&args).await?;
let hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
HawkHandle::new(hawk_actor).await
}

Expand Down Expand Up @@ -2049,7 +2053,7 @@ mod tests_db {
disable_persistence: false,
tls: None,
};
let mut hawk_actor = HawkActor::from_cli(&args).await?;
let mut hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
let (_, graph_loader) = hawk_actor.as_iris_loader().await;
graph_loader.load_graph_store(&graph_store, 2).await?;

Expand Down
3 changes: 2 additions & 1 deletion iris-mpc-cpu/src/execution/hawk_main/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use itertools::Itertools;
use rand::SeedableRng;
use std::{sync::Arc, time::Duration};
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;

use crate::{
execution::local::get_free_local_addresses,
Expand Down Expand Up @@ -42,7 +43,7 @@ pub async fn setup_hawk_actors() -> Result<Vec<HawkActor>> {
// Make the test async.
sleep(Duration::from_millis(index as u64)).await;

HawkActor::from_cli(&args).await
HawkActor::from_cli(&args, CancellationToken::new()).await
}
};

Expand Down
173 changes: 144 additions & 29 deletions iris-mpc-cpu/src/network/tcp/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,23 @@ use async_trait::async_trait;
use bytes::BytesMut;
use eyre::Result;
use iris_mpc_common::fast_metrics::FastHistogram;
use std::io;
use std::{collections::HashMap, time::Instant};
use std::{
collections::HashMap,
io,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Instant,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
sync::{
mpsc::{self, error::TryRecvError, UnboundedReceiver, UnboundedSender},
oneshot,
oneshot, Mutex,
},
};
use tokio_util::sync::CancellationToken;

const BUFFER_CAPACITY: usize = 32 * 1024;
const READ_BUF_SIZE: usize = 2 * 1024 * 1024;
Expand Down Expand Up @@ -69,21 +77,26 @@ impl<T: NetworkConnection + 'static> TcpNetworkHandle<T> {
reconnector: Reconnector<T>,
connections: PeerConnections<T>,
config: TcpConfig,
ct: CancellationToken,
) -> Self {
let peers = connections.keys().cloned().collect();
let mut ch_map = HashMap::new();
let conn_state = ConnectionState::new();
for (peer_id, connections) in connections {
let mut m = HashMap::new();
for (stream_id, connection) in connections {
let rc = reconnector.clone();
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
m.insert(stream_id, cmd_tx);

let ct2 = ct.clone();
tokio::spawn(manage_connection(
connection,
rc,
config.get_sessions_for_stream(&stream_id),
cmd_rx,
ct2,
conn_state.clone(),
));
}
ch_map.insert(peer_id.clone(), m);
Expand Down Expand Up @@ -220,6 +233,8 @@ async fn manage_connection<T: NetworkConnection>(
reconnector: Reconnector<T>,
num_sessions: usize,
mut cmd_ch: UnboundedReceiver<Cmd>,
ct: CancellationToken,
conn_state: ConnectionState,
) {
let Connection {
peer,
Expand Down Expand Up @@ -257,35 +272,55 @@ async fn manage_connection<T: NetworkConnection>(
// when new sessions are requested, also tear down and stand up the forwarders.
loop {
let r_writer = writer.as_mut().expect("writer should be Some");
let r_outbound = &mut outbound_rx;
let outbound_task = async move {
let r = handle_outbound_traffic(r_writer, r_outbound, num_sessions).await;
tracing::warn!("handle_outbound_traffic exited: {r:?}");
};
let outbound_task = handle_outbound_traffic(r_writer, &mut outbound_rx, num_sessions);

let r_reader = reader.as_mut().expect("reader should be Some");
let r_inbound = &inbound_forwarder;
let inbound_task = async move {
let r = handle_inbound_traffic(r_reader, r_inbound).await;
tracing::warn!("handle_inbound_traffic exited: {r:?}");
};
let inbound_task = handle_inbound_traffic(r_reader, &inbound_forwarder);

enum Evt {
Cmd(Cmd),
Disconnected,
Shutdown,
}
let event = tokio::select! {
maybe_cmd = cmd_ch.recv() => {
match maybe_cmd {
Some(cmd) => Evt::Cmd(cmd),
None => {
tracing::info!("cmd channel closed");
// basically means the networking stack was shut down before a command was received.
if conn_state.exited() {
tracing::info!("cmd channel closed");
};
ct.cancel();
return;
}
}
}
_ = inbound_task => Evt::Disconnected,
_ = outbound_task => Evt::Disconnected,
r = inbound_task => {
if ct.is_cancelled() {
Evt::Shutdown
} else if let Err(e) = r {
if conn_state.incr_reconnect().await {
tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
}
Evt::Disconnected
} else {
unreachable!();
}
},
r = outbound_task => {
if ct.is_cancelled() {
Evt::Shutdown
} else if let Err(e) = r {
if conn_state.incr_reconnect().await {
tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
}
Evt::Disconnected
} else {
unreachable!();
}
},
_ = ct.cancelled() => Evt::Shutdown,
};

// update the Arcs depending on the event. wait for reconnect if needed.
Expand Down Expand Up @@ -314,21 +349,35 @@ async fn manage_connection<T: NetworkConnection>(
)
.await
{
tracing::error!("reconnect failed: {e:?}");
tracing::debug!("reconnect failed: {e:?}");
return;
};
rsp.send(Ok(())).unwrap();
}
},
Evt::Disconnected => {
tracing::info!("reconnecting to {:?}: {:?}", peer, stream_id);
tracing::debug!("reconnecting to {:?}: {:?}", peer, stream_id);
if let Err(e) =
reconnect_and_replace(&reconnector, &peer, stream_id, &mut reader, &mut writer)
.await
{
tracing::error!("reconnect failed: {e:?}");
if conn_state.exited() {
if ct.is_cancelled() {
tracing::info!("shutting down TCP/TLS networking stack");
} else {
tracing::error!("reconnect failed: {e:?}");
}
}
return;
};
} else if conn_state.decr_reconnect().await {
tracing::info!("all connections re-established");
}
}
Evt::Shutdown => {
if conn_state.exited() {
tracing::info!("shutting down TCP/TLS networking stack");
}
return;
}
}
}
Expand Down Expand Up @@ -412,17 +461,11 @@ async fn handle_outbound_traffic<T: NetworkConnection>(
}
}

if let Err(e) = write_buf(stream, &mut buf).await {
tracing::error!(error=%e, "Failed to flush buffer on outbound_rx");
return Err(e);
}
write_buf(stream, &mut buf).await?
}

if !buf.is_empty() {
if let Err(e) = write_buf(stream, &mut buf).await {
tracing::error!(error=%e, "Failed to flush buffer when outbound_rx closed");
return Err(e);
}
write_buf(stream, &mut buf).await?
}
// the channel will not receive any more commands
tracing::debug!("outbound_rx closed");
Expand Down Expand Up @@ -493,7 +536,10 @@ async fn handle_inbound_traffic<T: NetworkConnection>(
}
}
Err(e) => {
tracing::error!("failed to deserialize message: {e}");
return Err(io::Error::new(
io::ErrorKind::Other,
format!("failed to deserialize message: {e}"),
));
}
};
} else {
Expand All @@ -515,3 +561,72 @@ async fn write_buf<T: NetworkConnection>(
buf.clear();
Ok(())
}

// state which is shared by all connections. used to reduce the number of logs
// emitted upon loss of network connectivity.
#[derive(Clone)]
struct ConnectionState {
inner: Arc<ConnectionStateInner>,
}

impl ConnectionState {
fn new() -> Self {
Self {
inner: Arc::new(ConnectionStateInner::new()),
}
}

fn exited(&self) -> bool {
self.inner
.exited
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
}

async fn incr_reconnect(&self) -> bool {
self.inner.reconnecting.increment().await
}

async fn decr_reconnect(&self) -> bool {
self.inner.reconnecting.decrement().await
}
}

struct ConnectionStateInner {
reconnecting: Counter,
exited: AtomicBool,
}

impl ConnectionStateInner {
fn new() -> Self {
Self {
reconnecting: Counter::new(),
exited: AtomicBool::new(false),
}
}
}

struct Counter {
num: Mutex<usize>,
}

impl Counter {
fn new() -> Self {
Self { num: Mutex::new(0) }
}

// returns true if num was zero
async fn increment(&self) -> bool {
let mut l = self.num.lock().await;
*l += 1;
*l == 1
}

// returns true if num was one before decrementing
async fn decrement(&self) -> bool {
let mut l = self.num.lock().await;
let r = *l == 1;
*l = l.saturating_sub(1);
r
}
}
Loading
Loading