diff --git a/deploy/dev/ampc-hnsw-0-dev/values-ampc-hnsw.yaml b/deploy/dev/ampc-hnsw-0-dev/values-ampc-hnsw.yaml
index a3544e9c7..7308ea2fb 100644
--- a/deploy/dev/ampc-hnsw-0-dev/values-ampc-hnsw.yaml
+++ b/deploy/dev/ampc-hnsw-0-dev/values-ampc-hnsw.yaml
@@ -55,7 +55,7 @@ env:
       value: "30"
 
     - name: SMPC__PROCESSING_TIMEOUT_SECS
-      value: "1240"
+      value: "1800"
 
     - name: SMPC__HAWK_REQUEST_PARALLELISM
       value: "32"
diff --git a/deploy/dev/ampc-hnsw-1-dev/values-ampc-hnsw.yaml b/deploy/dev/ampc-hnsw-1-dev/values-ampc-hnsw.yaml
index eb22fddb3..55bf46f67 100644
--- a/deploy/dev/ampc-hnsw-1-dev/values-ampc-hnsw.yaml
+++ b/deploy/dev/ampc-hnsw-1-dev/values-ampc-hnsw.yaml
@@ -55,7 +55,7 @@ env:
       value: "30"
 
     - name: SMPC__PROCESSING_TIMEOUT_SECS
-      value: "1240"
+      value: "1800"
 
    - name: SMPC__HAWK_REQUEST_PARALLELISM
       value: "32"
diff --git a/deploy/dev/ampc-hnsw-2-dev/values-ampc-hnsw.yaml b/deploy/dev/ampc-hnsw-2-dev/values-ampc-hnsw.yaml
index 9c4f79033..8196eb68e 100644
--- a/deploy/dev/ampc-hnsw-2-dev/values-ampc-hnsw.yaml
+++ b/deploy/dev/ampc-hnsw-2-dev/values-ampc-hnsw.yaml
@@ -55,7 +55,7 @@ env:
       value: "30"
 
     - name: SMPC__PROCESSING_TIMEOUT_SECS
-      value: "1240"
+      value: "1800"
 
     - name: SMPC__HAWK_REQUEST_PARALLELISM
       value: "32"
diff --git a/deploy/dev/common-values-ampc-hnsw.yaml b/deploy/dev/common-values-ampc-hnsw.yaml
index f95ae02ab..a7736bbe1 100644
--- a/deploy/dev/common-values-ampc-hnsw.yaml
+++ b/deploy/dev/common-values-ampc-hnsw.yaml
@@ -1,4 +1,4 @@
-image: "ghcr.io/worldcoin/iris-mpc-cpu:d3a394f65b2f99a76111441373077b1c6dcf254f"
+image: "ghcr.io/worldcoin/iris-mpc-cpu:04dc776bc0a4c046e9ecbaad797e8605328cb121"
 environment: dev
 replicaCount: 1
 
diff --git a/iris-mpc-common/src/helpers/shutdown_handler.rs b/iris-mpc-common/src/helpers/shutdown_handler.rs
index b98db96e4..f6a159417 100644
--- a/iris-mpc-common/src/helpers/shutdown_handler.rs
+++ b/iris-mpc-common/src/helpers/shutdown_handler.rs
@@ -36,6 +36,10 @@ impl ShutdownHandler {
         self.ct.cancelled().await
     }
 
+    pub fn get_cancellation_token(&self) -> CancellationToken {
+        self.ct.clone()
+    }
+
     pub async fn register_signal_handler(&self) {
         let ct = self.ct.clone();
         tokio::spawn(async move {
diff --git a/iris-mpc-cpu/src/execution/hawk_main.rs b/iris-mpc-cpu/src/execution/hawk_main.rs
index 7191897ff..9ce74909d 100644
--- a/iris-mpc-cpu/src/execution/hawk_main.rs
+++ b/iris-mpc-cpu/src/execution/hawk_main.rs
@@ -67,6 +67,7 @@ use tokio::{
     join,
     sync::{mpsc, oneshot, RwLock, RwLockWriteGuard},
 };
+use tokio_util::sync::CancellationToken;
 
 pub type GraphStore = graph_store::GraphPg<Aby3Store>;
 pub type GraphTx<'a> = graph_store::GraphTx<'a, Aby3Store>;
@@ -292,9 +293,10 @@ impl HawkInsertPlan {
     }
 }
 
 impl HawkActor {
-    pub async fn from_cli(args: &HawkArgs) -> Result<Self> {
+    pub async fn from_cli(args: &HawkArgs, ct: CancellationToken) -> Result<Self> {
         Self::from_cli_with_graph_and_store(
             args,
+            ct,
             [(); 2].map(|_| GraphMem::new()),
             [(); 2].map(|_| Aby3Store::new_storage(None)),
         )
@@ -303,6 +305,7 @@ impl HawkActor {
     pub async fn from_cli_with_graph_and_store(
         args: &HawkArgs,
+        ct: CancellationToken,
         graph: BothEyes<GraphMem<Aby3Store>>,
         iris_store: BothEyes,
     ) -> Result<Self> {
@@ -326,7 +329,8 @@ impl HawkActor {
         let my_index = args.party_index;
         let networking =
-            build_network_handle(args, &identities, SessionGroups::N_SESSIONS_PER_REQUEST).await?;
+            build_network_handle(args, ct, &identities, SessionGroups::N_SESSIONS_PER_REQUEST)
+                .await?;
         let graph_store = graph.map(GraphMem::to_arc);
         let iris_store = iris_store.map(SharedIrises::to_arc);
         let workers_handle = [LEFT, RIGHT]
@@ -1653,7 +1657,7 @@ impl HawkHandle {
 
 pub async fn hawk_main(args: HawkArgs) -> Result<HawkHandle> {
     println!("🦅 Starting Hawk node {}", args.party_index);
-    let hawk_actor = HawkActor::from_cli(&args).await?;
+    let hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
     HawkHandle::new(hawk_actor).await
 }
 
@@ -2049,7 +2053,7 @@ mod tests_db {
             disable_persistence: false,
             tls: None,
         };
-        let mut hawk_actor = HawkActor::from_cli(&args).await?;
+        let mut hawk_actor = HawkActor::from_cli(&args, CancellationToken::new()).await?;
 
         let (_, graph_loader) = hawk_actor.as_iris_loader().await;
         graph_loader.load_graph_store(&graph_store, 2).await?;
diff --git a/iris-mpc-cpu/src/execution/hawk_main/test_utils.rs b/iris-mpc-cpu/src/execution/hawk_main/test_utils.rs
index 49b58afb4..b6e195b9d 100644
--- a/iris-mpc-cpu/src/execution/hawk_main/test_utils.rs
+++ b/iris-mpc-cpu/src/execution/hawk_main/test_utils.rs
@@ -11,6 +11,7 @@ use itertools::Itertools;
 use rand::SeedableRng;
 use std::{sync::Arc, time::Duration};
 use tokio::time::sleep;
+use tokio_util::sync::CancellationToken;
 
 use crate::{
     execution::local::get_free_local_addresses,
@@ -42,7 +43,7 @@ pub async fn setup_hawk_actors() -> Result<Vec<HawkActor>> {
             // Make the test async.
             sleep(Duration::from_millis(index as u64)).await;
 
-            HawkActor::from_cli(&args).await
+            HawkActor::from_cli(&args, CancellationToken::new()).await
         }
     };
diff --git a/iris-mpc-cpu/src/network/tcp/handle.rs b/iris-mpc-cpu/src/network/tcp/handle.rs
index ec583859c..df87976bc 100644
--- a/iris-mpc-cpu/src/network/tcp/handle.rs
+++ b/iris-mpc-cpu/src/network/tcp/handle.rs
@@ -13,15 +13,23 @@ use async_trait::async_trait;
 use bytes::BytesMut;
 use eyre::Result;
 use iris_mpc_common::fast_metrics::FastHistogram;
-use std::io;
-use std::{collections::HashMap, time::Instant};
+use std::{
+    collections::HashMap,
+    io,
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        Arc,
+    },
+    time::Instant,
+};
 use tokio::{
     io::{AsyncReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf},
     sync::{
         mpsc::{self, error::TryRecvError, UnboundedReceiver, UnboundedSender},
-        oneshot,
+        oneshot, Mutex,
     },
 };
+use tokio_util::sync::CancellationToken;
 
 const BUFFER_CAPACITY: usize = 32 * 1024;
 const READ_BUF_SIZE: usize = 2 * 1024 * 1024;
@@ -69,9 +77,11 @@ impl TcpNetworkHandle {
         reconnector: Reconnector,
         connections: PeerConnections,
         config: TcpConfig,
+        ct: CancellationToken,
     ) -> Self {
         let peers = connections.keys().cloned().collect();
         let mut ch_map = HashMap::new();
+        let conn_state = ConnectionState::new();
         for (peer_id, connections) in connections {
             let mut m = HashMap::new();
             for (stream_id, connection) in connections {
@@ -79,11 +89,14 @@ impl TcpNetworkHandle {
                 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
                 m.insert(stream_id, cmd_tx);
 
+                let ct2 = ct.clone();
                 tokio::spawn(manage_connection(
                     connection,
                     rc,
                     config.get_sessions_for_stream(&stream_id),
                     cmd_rx,
+                    ct2,
+                    conn_state.clone(),
                 ));
             }
             ch_map.insert(peer_id.clone(), m);
@@ -220,6 +233,8 @@ async fn manage_connection(
     reconnector: Reconnector,
     num_sessions: usize,
     mut cmd_ch: UnboundedReceiver<Cmd>,
+    ct: CancellationToken,
+    conn_state: ConnectionState,
 ) {
     let Connection {
         peer,
@@ -257,35 +272,55 @@ async fn manage_connection(
 
     // when new sessions are requested, also tear down and stand up the forwarders.
    loop {
         let r_writer = writer.as_mut().expect("writer should be Some");
-        let r_outbound = &mut outbound_rx;
-        let outbound_task = async move {
-            let r = handle_outbound_traffic(r_writer, r_outbound, num_sessions).await;
-            tracing::warn!("handle_outbound_traffic exited: {r:?}");
-        };
+        let outbound_task = handle_outbound_traffic(r_writer, &mut outbound_rx, num_sessions);
 
         let r_reader = reader.as_mut().expect("reader should be Some");
-        let r_inbound = &inbound_forwarder;
-        let inbound_task = async move {
-            let r = handle_inbound_traffic(r_reader, r_inbound).await;
-            tracing::warn!("handle_inbound_traffic exited: {r:?}");
-        };
+        let inbound_task = handle_inbound_traffic(r_reader, &inbound_forwarder);
 
         enum Evt {
             Cmd(Cmd),
             Disconnected,
+            Shutdown,
         }
 
         let event = tokio::select! {
             maybe_cmd = cmd_ch.recv() => {
                 match maybe_cmd {
                     Some(cmd) => Evt::Cmd(cmd),
                     None => {
-                        tracing::info!("cmd channel closed");
+                        // the cmd channel closing here means the networking stack was shut down
+                        // before a command was received.
+                        if conn_state.exited() {
+                            tracing::info!("cmd channel closed");
+                        };
+                        ct.cancel();
                         return;
                     }
                 }
             }
-            _ = inbound_task => Evt::Disconnected,
-            _ = outbound_task => Evt::Disconnected,
+            r = inbound_task => {
+                if ct.is_cancelled() {
+                    Evt::Shutdown
+                } else if let Err(e) = r {
+                    if conn_state.incr_reconnect().await {
+                        tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
+                    }
+                    Evt::Disconnected
+                } else {
+                    unreachable!();
+                }
+            },
+            r = outbound_task => {
+                if ct.is_cancelled() {
+                    Evt::Shutdown
+                } else if let Err(e) = r {
+                    if conn_state.incr_reconnect().await {
+                        tracing::error!(e=%e, "TCP/TLS connection closed unexpectedly. reconnecting...");
+                    }
+                    Evt::Disconnected
+                } else {
+                    unreachable!();
+                }
+            },
+            _ = ct.cancelled() => Evt::Shutdown,
         };
 
         // update the Arcs depending on the event. wait for reconnect if needed.
@@ -314,21 +349,35 @@ async fn manage_connection(
                     )
                     .await
                     {
-                        tracing::error!("reconnect failed: {e:?}");
+                        tracing::debug!("reconnect failed: {e:?}");
                         return;
                     };
                     rsp.send(Ok(())).unwrap();
                 }
             },
             Evt::Disconnected => {
-                tracing::info!("reconnecting to {:?}: {:?}", peer, stream_id);
+                tracing::debug!("reconnecting to {:?}: {:?}", peer, stream_id);
                 if let Err(e) =
                     reconnect_and_replace(&reconnector, &peer, stream_id, &mut reader, &mut writer)
                         .await
                 {
-                    tracing::error!("reconnect failed: {e:?}");
+                    if conn_state.exited() {
+                        if ct.is_cancelled() {
+                            tracing::info!("shutting down TCP/TLS networking stack");
+                        } else {
+                            tracing::error!("reconnect failed: {e:?}");
+                        }
+                    }
                     return;
-                };
+                } else if conn_state.decr_reconnect().await {
+                    tracing::info!("all connections re-established");
+                }
+            }
+            Evt::Shutdown => {
+                if conn_state.exited() {
+                    tracing::info!("shutting down TCP/TLS networking stack");
+                }
+                return;
             }
         }
     }
 }
@@ -412,17 +461,11 @@ async fn handle_outbound_traffic(
             }
         }
 
-        if let Err(e) = write_buf(stream, &mut buf).await {
-            tracing::error!(error=%e, "Failed to flush buffer on outbound_rx");
-            return Err(e);
-        }
+        write_buf(stream, &mut buf).await?
     }
 
     if !buf.is_empty() {
-        if let Err(e) = write_buf(stream, &mut buf).await {
-            tracing::error!(error=%e, "Failed to flush buffer when outbound_rx closed");
-            return Err(e);
-        }
+        write_buf(stream, &mut buf).await?
    }
 
     // the channel will not receive any more commands
     tracing::debug!("outbound_rx closed");
@@ -493,7 +536,10 @@ async fn handle_inbound_traffic(
                 }
             }
             Err(e) => {
-                tracing::error!("failed to deserialize message: {e}");
+                return Err(io::Error::new(
+                    io::ErrorKind::Other,
+                    format!("failed to deserialize message: {e}"),
+                ));
             }
         };
     } else {
@@ -515,3 +561,72 @@ async fn write_buf(
     buf.clear();
     Ok(())
 }
+
+// State shared by all connections; used to reduce the number of logs
+// emitted upon loss of network connectivity.
+#[derive(Clone)]
+struct ConnectionState {
+    inner: Arc<ConnectionStateInner>,
+}
+
+impl ConnectionState {
+    fn new() -> Self {
+        Self {
+            inner: Arc::new(ConnectionStateInner::new()),
+        }
+    }
+
+    // returns true only for the first caller; subsequent calls see the flag already set
+    fn exited(&self) -> bool {
+        self.inner
+            .exited
+            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
+            .is_ok()
+    }
+
+    async fn incr_reconnect(&self) -> bool {
+        self.inner.reconnecting.increment().await
+    }
+
+    async fn decr_reconnect(&self) -> bool {
+        self.inner.reconnecting.decrement().await
+    }
+}
+
+struct ConnectionStateInner {
+    reconnecting: Counter,
+    exited: AtomicBool,
+}
+
+impl ConnectionStateInner {
+    fn new() -> Self {
+        Self {
+            reconnecting: Counter::new(),
+            exited: AtomicBool::new(false),
+        }
+    }
+}
+
+struct Counter {
+    num: Mutex<usize>,
+}
+
+impl Counter {
+    fn new() -> Self {
+        Self { num: Mutex::new(0) }
+    }
+
+    // returns true if num was zero before incrementing
+    async fn increment(&self) -> bool {
+        let mut l = self.num.lock().await;
+        *l += 1;
+        *l == 1
+    }
+
+    // returns true if num was one before decrementing
+    async fn decrement(&self) -> bool {
+        let mut l = self.num.lock().await;
+        let r = *l == 1;
+        *l = l.saturating_sub(1);
+        r
+    }
+}
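
Note on the `ConnectionState` machinery above: every per-connection task shares one `ConnectionState`, so only the first task to observe a failure logs the error, and only the task that restores the last outstanding connection logs "all connections re-established"; the `exited` flag likewise lets exactly one task log the shutdown message. Below is a minimal, self-contained sketch of the counting half of that pattern; `ReconnectTracker` and its method names are illustrative stand-ins, not the real types.

```rust
use std::sync::Arc;
use tokio::sync::Mutex;

// Shared by all connection tasks: the first task to start reconnecting logs
// the outage, and the task that finishes the last reconnect logs recovery.
#[derive(Clone)]
struct ReconnectTracker {
    reconnecting: Arc<Mutex<usize>>,
}

impl ReconnectTracker {
    fn new() -> Self {
        Self { reconnecting: Arc::new(Mutex::new(0)) }
    }

    // returns true if this is the first connection to start reconnecting
    async fn enter(&self) -> bool {
        let mut n = self.reconnecting.lock().await;
        *n += 1;
        *n == 1
    }

    // returns true if this was the last connection still reconnecting
    async fn exit(&self) -> bool {
        let mut n = self.reconnecting.lock().await;
        let last = *n == 1;
        *n = n.saturating_sub(1);
        last
    }
}

#[tokio::main]
async fn main() {
    let tracker = ReconnectTracker::new();
    let mut tasks = vec![];
    for id in 0..4 {
        let t = tracker.clone();
        tasks.push(tokio::spawn(async move {
            if t.enter().await {
                println!("connection {id}: lost connectivity, reconnecting...");
            }
            // ... dial the peer again here ...
            if t.exit().await {
                println!("connection {id}: all connections re-established");
            }
        }));
    }
    for t in tasks {
        t.await.unwrap();
    }
}
```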
diff --git a/iris-mpc-cpu/src/network/tcp/mod.rs b/iris-mpc-cpu/src/network/tcp/mod.rs
index 06f2a7893..0313af087 100644
--- a/iris-mpc-cpu/src/network/tcp/mod.rs
+++ b/iris-mpc-cpu/src/network/tcp/mod.rs
@@ -20,6 +20,7 @@ use std::{
     time::Duration,
 };
 use tokio::io::{AsyncRead, AsyncWrite};
+use tokio_util::sync::CancellationToken;
 
 pub mod config;
 mod data;
@@ -55,6 +56,7 @@ pub trait Server: Send {
 
 pub async fn build_network_handle(
     args: &HawkArgs,
+    ct: CancellationToken,
     identities: &[Identity],
     sessions_per_request: usize,
 ) -> Result<Box<dyn NetworkHandle>> {
@@ -79,6 +81,33 @@ pub async fn build_network_handle(
         args.request_parallelism * sessions_per_request,
     );
 
+    // PeerConnectionBuilder is generic over listener and connector, and we don't want to use a boxed trait to
+    // reduce code duplication. Instead, use a macro which makes use of the local variables.
+    macro_rules! build_network_handle {
+        ($listener:expr, $connector:expr) => {{
+            let connection_builder = PeerConnectionBuilder::new(
+                my_identity,
+                tcp_config.clone(),
+                $listener,
+                $connector,
+                ct.clone(),
+            )
+            .await?;
+
+            for (identity, url) in
+                izip!(identities, &args.addresses).filter(|(_, address)| address != &my_address)
+            {
+                connection_builder
+                    .include_peer(identity.clone(), url.clone())
+                    .await?;
+            }
+
+            let (reconnector, connections) = connection_builder.build().await?;
+            let networking = TcpNetworkHandle::new(reconnector, connections, tcp_config, ct);
+            Ok(Box::new(networking))
+        }};
+    }
+
     if let Some(tls) = args.tls.as_ref() {
         tracing::info!(
             "Building NetworkHandle, with TLS, from configs: {:?} {:?}",
@@ -92,22 +121,7 @@ pub async fn build_network_handle(
             let listener = BoxTcpServer(TcpServer::new(my_addr).await?);
             let connector = BoxTlsClient(TlsClient::new_with_ca_certs(&root_certs).await?);
 
-            let connection_builder =
-                PeerConnectionBuilder::new(my_identity, tcp_config.clone(), listener, connector)
-                    .await?;
-
-            // Connect to other players.
-            for (identity, url) in
-                izip!(identities, &args.addresses).filter(|(_, address)| address != &my_address)
-            {
-                connection_builder
-                    .include_peer(identity.clone(), url.clone())
-                    .await?;
-            }
-
-            let (reconnector, connections) = connection_builder.build().await?;
-            let networking = TcpNetworkHandle::new(reconnector, connections, tcp_config);
-            Ok(Box::new(networking))
+            build_network_handle!(listener, connector)
         } else {
             tracing::info!("Running in full app TLS mode.");
             if tls.private_key.is_none() || tls.leaf_cert.is_none() {
@@ -127,22 +141,7 @@ pub async fn build_network_handle(
             let listener = TlsServer::new(my_addr, private_key, leaf_cert, &root_certs).await?;
             let connector = TlsClient::new_with_ca_certs(&root_certs).await?;
-
-            let connection_builder =
-                PeerConnectionBuilder::new(my_identity, tcp_config.clone(), listener, connector)
-                    .await?;
-            // Connect to other players.
-            for (identity, url) in
-                izip!(identities, &args.addresses).filter(|(_, address)| address != &my_address)
-            {
-                connection_builder
-                    .include_peer(identity.clone(), url.clone())
-                    .await?;
-            }
-
-            let (reconnector, connections) = connection_builder.build().await?;
-            let networking = TcpNetworkHandle::new(reconnector, connections, tcp_config);
-            Ok(Box::new(networking))
+            build_network_handle!(listener, connector)
         }
     } else {
         tracing::info!(
@@ -151,22 +150,7 @@ pub async fn build_network_handle(
         );
         let listener = BoxTcpServer(TcpServer::new(my_addr).await?);
         let connector = BoxTcpClient(TcpClient::default());
-        let connection_builder =
-            PeerConnectionBuilder::new(my_identity, tcp_config.clone(), listener, connector)
-                .await?;
-
-        // Connect to other players.
-        for (identity, url) in
-            izip!(identities, &args.addresses).filter(|(_, address)| address != &my_address)
-        {
-            connection_builder
-                .include_peer(identity.clone(), url.clone())
-                .await?;
-        }
-
-        let (reconnector, connections) = connection_builder.build().await?;
-        let networking = TcpNetworkHandle::new(reconnector, connections, tcp_config);
-        Ok(Box::new(networking))
+        build_network_handle!(listener, connector)
     }
 }
 
@@ -185,6 +169,7 @@ pub mod testing {
     use itertools::izip;
     use std::{collections::HashSet, net::SocketAddr, sync::LazyLock, time::Duration};
     use tokio::{net::TcpStream, sync::Mutex, time::sleep};
+    use tokio_util::sync::CancellationToken;
 
     use crate::{
         execution::player::Identity,
@@ -237,6 +222,7 @@ pub mod testing {
         // Create NetworkHandles for each party
         let mut builders = Vec::with_capacity(parties.len());
         let connector = TcpClient::default();
+        let ct = CancellationToken::new();
         for (party, addr) in izip!(parties.iter(), addresses.iter()) {
             let listener = TcpServer::new(*addr).await?;
             builders.push(
@@ -245,6 +231,7 @@ pub mod testing {
                     config.clone(),
                     listener,
                     connector.clone(),
+                    ct.clone(),
                 )
                 .await?,
             );
@@ -273,9 +260,10 @@ pub mod testing {
         }
         tracing::debug!("Players connected to each other");
 
+        let ct = CancellationToken::new();
         let mut handles = vec![];
         for (r, c) in connections {
-            handles.push(TcpNetworkHandle::new(r, c, config.clone()));
+            handles.push(TcpNetworkHandle::new(r, c, config.clone(), ct.clone()));
         }
 
         tracing::debug!("waiting for make_sessions to complete");
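
Note on the `build_network_handle!` macro introduced above: each branch of the TLS/plain-TCP decision produces a different concrete listener/connector type, and `PeerConnectionBuilder::new` is generic over both, so a shared helper function would have to thread those generic parameters through; a local `macro_rules!` instead expands in place and can capture surrounding locals. A minimal sketch of that pattern with hypothetical stand-in types (`TcpListenerStub`, `TlsListenerStub`, `build` are not from this codebase):

```rust
#[derive(Debug)]
struct TcpListenerStub;
#[derive(Debug)]
struct TlsListenerStub;

// generic over the listener type, like PeerConnectionBuilder::new
fn build<L: std::fmt::Debug>(listener: L, shared: &str) -> String {
    format!("built with {listener:?} for {shared}")
}

fn main() {
    let shared = String::from("node-0");
    let use_tls = std::env::args().any(|a| a == "--tls");

    // the macro captures `shared` from the enclosing scope, so each branch
    // only names the part that differs (the concrete listener type).
    macro_rules! build_handle {
        ($listener:expr) => {{
            build($listener, &shared)
        }};
    }

    let handle = if use_tls {
        build_handle!(TlsListenerStub)
    } else {
        build_handle!(TcpListenerStub)
    };
    println!("{handle}");
}
```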
diff --git a/iris-mpc-cpu/src/network/tcp/networking/connection_builder.rs b/iris-mpc-cpu/src/network/tcp/networking/connection_builder.rs
index d5d34645e..465d31c37 100644
--- a/iris-mpc-cpu/src/network/tcp/networking/connection_builder.rs
+++ b/iris-mpc-cpu/src/network/tcp/networking/connection_builder.rs
@@ -87,8 +87,9 @@ where
         tcp_config: TcpConfig,
         listener: S,
         connector: C,
+        ct: CancellationToken,
     ) -> Result<Self> {
-        let cmd_tx = Worker::spawn(id.clone(), listener, connector).await?;
+        let cmd_tx = Worker::spawn(id.clone(), listener, connector, ct).await?;
         Ok(Self {
             id,
             tcp_config,
@@ -195,13 +196,14 @@ where
         id: Identity,
         listener: S,
         connector: C,
+        ct: CancellationToken,
     ) -> io::Result>> {
         let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
         let (pending_tx, pending_rx) = mpsc::unbounded_channel::>();
         let mut worker = Self {
             id: id.clone(),
             cmd_rx,
-            ct: CancellationToken::new(),
+            ct,
             peer_addrs: HashMap::new(),
             connector,
             pending_connections: HashMap::new(),
@@ -475,6 +477,8 @@ where
             peer,
             stream_id
         );
+
+        // creating a ThreadRng here would make the future not Send
         sleep(Duration::from_millis(delay_ms)).await;
         loop {
             let r = tokio::select! {
@@ -491,13 +495,13 @@ where
                 } else {
                     let pending = Connection::new(peer, stream, stream_id);
                     if pending_tx.send(pending).is_err() {
-                        tracing::error!("accept loop receiver dropped");
+                        tracing::debug!("accept loop receiver dropped");
                     }
                     break;
                 }
             }
             Err(e) => {
-                tracing::warn!(%e, "dial {:?} failed, retrying", url.clone());
+                tracing::debug!(%e, "dial {:?} failed, retrying", url.clone());
             }
         };
         sleep(Duration::from_secs(retry_sec)).await;
diff --git a/iris-mpc-cpu/tests/e2e-anon-stats.rs b/iris-mpc-cpu/tests/e2e-anon-stats.rs
index 15d836c82..b8c81626b 100644
--- a/iris-mpc-cpu/tests/e2e-anon-stats.rs
+++ b/iris-mpc-cpu/tests/e2e-anon-stats.rs
@@ -16,6 +16,7 @@ use iris_mpc_cpu::{
 };
 use rand::{random, rngs::StdRng, SeedableRng};
 use std::{collections::HashMap, env, sync::Arc, time::Duration};
+use tokio_util::sync::CancellationToken;
 use tracing_subscriber::{fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt};
 
 const DB_SIZE: usize = 1000;
@@ -147,7 +148,9 @@ async fn start_hawk_node(
     );
     let (graph, iris_store) =
         create_graph_from_plain_dbs(args.party_index, db_seed, left_db, right_db, &params).await?;
-    let hawk_actor = HawkActor::from_cli_with_graph_and_store(args, graph, iris_store).await?;
+    let hawk_actor =
+        HawkActor::from_cli_with_graph_and_store(args, CancellationToken::new(), graph, iris_store)
+            .await?;
 
     let handle = HawkHandle::new(hawk_actor).await?;
diff --git a/iris-mpc-cpu/tests/e2e.rs b/iris-mpc-cpu/tests/e2e.rs
index 857eb4888..3f5cbdf2f 100644
--- a/iris-mpc-cpu/tests/e2e.rs
+++ b/iris-mpc-cpu/tests/e2e.rs
@@ -16,6 +16,7 @@ use iris_mpc_cpu::{
 };
 use rand::{rngs::StdRng, SeedableRng};
 use std::{collections::HashMap, sync::Arc, time::Duration};
+use tokio_util::sync::CancellationToken;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
 
 const DB_SIZE: usize = 1000;
@@ -142,7 +143,9 @@ async fn start_hawk_node(
     );
     let (graph, iris_store) =
         create_graph_from_plain_dbs(args.party_index, left_db, right_db, &params).await?;
-    let hawk_actor = HawkActor::from_cli_with_graph_and_store(args, graph, iris_store).await?;
+    let hawk_actor =
+        HawkActor::from_cli_with_graph_and_store(args, CancellationToken::new(), graph, iris_store)
+            .await?;
 
     let handle = HawkHandle::new(hawk_actor).await?;
diff --git a/iris-mpc-upgrade-hawk/src/genesis/mod.rs b/iris-mpc-upgrade-hawk/src/genesis/mod.rs
index fd0efe0bd..f7205071e 100644
--- a/iris-mpc-upgrade-hawk/src/genesis/mod.rs
+++ b/iris-mpc-upgrade-hawk/src/genesis/mod.rs
@@ -366,7 +366,7 @@ async fn exec_setup(
     log_info(String::from("Store consistency checks OK"));
 
     // Initialise HNSW graph from previously indexed.
-    let mut hawk_actor = get_hawk_actor(config).await?;
+    let mut hawk_actor = get_hawk_actor(config, &shutdown_handler).await?;
     init_graph_from_stores(
         config,
         &iris_store,
@@ -867,7 +867,10 @@ pub async fn exec_use_backup_as_source(
 ///
 /// * `config` - Application configuration instance.
 ///
-async fn get_hawk_actor(config: &Config) -> Result<HawkActor> {
+async fn get_hawk_actor(
+    config: &Config,
+    shutdown_handler: &Arc<ShutdownHandler>,
+) -> Result<HawkActor> {
     let node_addresses: Vec<String> = config
         .node_hostnames
         .iter()
@@ -896,7 +899,7 @@ async fn get_hawk_actor(config: &Config) -> Result<HawkActor> {
         hawk_args.party_index,
         node_addresses
     ));
-    HawkActor::from_cli(&hawk_args).await
+    HawkActor::from_cli(&hawk_args, shutdown_handler.get_cancellation_token()).await
 }
 
 /// Returns service clients used downstream.
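
Note on the shutdown wiring: `ShutdownHandler::get_cancellation_token` (added in `shutdown_handler.rs` above) hands out clones of the handler's `CancellationToken`; clones of a `tokio_util` token share the underlying state, so cancelling it wakes every subsystem holding a clone. A minimal sketch of that accessor pattern; `ShutdownHandlerSketch` and `trigger_shutdown` are illustrative stand-ins (the real handler cancels the token from its signal handler):

```rust
use tokio_util::sync::CancellationToken;

struct ShutdownHandlerSketch {
    ct: CancellationToken,
}

impl ShutdownHandlerSketch {
    fn new() -> Self {
        Self { ct: CancellationToken::new() }
    }

    // same shape as the accessor added in this diff
    fn get_cancellation_token(&self) -> CancellationToken {
        self.ct.clone()
    }

    fn trigger_shutdown(&self) {
        self.ct.cancel();
    }
}

#[tokio::main]
async fn main() {
    let handler = ShutdownHandlerSketch::new();
    let network_ct = handler.get_cancellation_token();

    // a subsystem (e.g. the TCP networking stack) waits on its clone
    let net_task = tokio::spawn(async move {
        network_ct.cancelled().await;
        println!("networking stack: shutting down");
    });

    handler.trigger_shutdown();
    net_task.await.unwrap();
}
```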
diff --git a/iris-mpc/src/server/mod.rs b/iris-mpc/src/server/mod.rs
index aedef77f0..601ee5c12 100644
--- a/iris-mpc/src/server/mod.rs
+++ b/iris-mpc/src/server/mod.rs
@@ -131,7 +131,7 @@ pub async fn server_main(config: Config) -> Result<()> {
         return Ok(());
     }
 
-    let mut hawk_actor = init_hawk_actor(&config).await?;
+    let mut hawk_actor = init_hawk_actor(&config, &shutdown_handler).await?;
 
     load_database(
         &config,
@@ -389,7 +389,10 @@ async fn sync_sqs_queues(
 
 /// Initialize main Hawk actor process for handling query batches using HNSW
 /// approximate k-nearest neighbors graph search.
-async fn init_hawk_actor(config: &Config) -> Result<HawkActor> {
+async fn init_hawk_actor(
+    config: &Config,
+    shutdown_handler: &Arc<ShutdownHandler>,
+) -> Result<HawkActor> {
     let node_addresses: Vec<String> = config
         .node_hostnames
         .iter()
@@ -419,7 +422,7 @@ async fn init_hawk_actor(config: &Config) -> Result<HawkActor> {
         node_addresses
     );
 
-    HawkActor::from_cli(&hawk_args).await
+    HawkActor::from_cli(&hawk_args, shutdown_handler.get_cancellation_token()).await
 }
 
 /// Loads iris code shares & HNSW graph from Postgres and/or S3.
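
Note on how the token is consumed: in `handle.rs` above, `manage_connection` races its I/O futures against `ct.cancelled()` in `tokio::select!` and maps cancellation to `Evt::Shutdown` (clean exit) rather than `Evt::Disconnected` (reconnect). A minimal sketch of that select pattern, with a `sleep` standing in for the inbound/outbound traffic handlers:

```rust
use tokio::time::{sleep, Duration};
use tokio_util::sync::CancellationToken;

// Races one iteration of "work" against cancellation, treating cancellation
// as an orderly shutdown instead of a connection error.
async fn manage(ct: CancellationToken, id: usize) {
    loop {
        tokio::select! {
            // stand-in for handle_inbound_traffic / handle_outbound_traffic
            _ = sleep(Duration::from_millis(50)) => {
                // one iteration of normal work completed; keep looping
            }
            _ = ct.cancelled() => {
                println!("task {id}: shutdown requested, exiting cleanly");
                return;
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let ct = CancellationToken::new();
    let tasks: Vec<_> = (0..3)
        .map(|id| tokio::spawn(manage(ct.clone(), id)))
        .collect();

    sleep(Duration::from_millis(120)).await;
    ct.cancel();
    for t in tasks {
        t.await.unwrap();
    }
}
```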