diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs
index f2453ebbc..5a881f5fe 100644
--- a/iris-mpc-common/src/config/mod.rs
+++ b/iris-mpc-common/src/config/mod.rs
@@ -229,6 +229,9 @@ pub struct Config {
     #[serde(default)]
     pub hawk_prf_key: Option,
 
+    #[serde(default = "default_hawk_numa")]
+    pub hawk_numa: bool,
+
     #[serde(default = "default_max_deletions_per_batch")]
     pub max_deletions_per_batch: usize,
 
@@ -399,6 +402,10 @@ fn default_hnsw_param_ef_search() -> usize {
     256
 }
 
+fn default_hawk_numa() -> bool {
+    true
+}
+
 fn default_service_ports() -> Vec<String> {
     vec!["4000".to_string(); 3]
 }
@@ -762,6 +769,7 @@ impl From<Config> for CommonConfig {
             hnsw_param_M,
             hnsw_param_ef_search,
             hawk_prf_key,
+            hawk_numa: _, // could be different for each server
             max_deletions_per_batch,
             enable_modifications_sync,
             enable_modifications_replay,
diff --git a/iris-mpc-cpu/src/execution/hawk_main.rs b/iris-mpc-cpu/src/execution/hawk_main.rs
index b5cd79fd7..2ee021c03 100644
--- a/iris-mpc-cpu/src/execution/hawk_main.rs
+++ b/iris-mpc-cpu/src/execution/hawk_main.rs
@@ -23,7 +23,7 @@ use crate::{
 };
 use clap::Parser;
 use eyre::{eyre, Report, Result};
-use futures::try_join;
+use futures::{future::try_join_all, try_join};
 use intra_batch::intra_batch_is_match;
 use iris_mpc_common::job::Eye;
 use iris_mpc_common::{
@@ -125,6 +125,9 @@ pub struct HawkArgs {
 
     #[clap(flatten)]
     pub tls: Option,
+
+    #[clap(long, default_value_t = false)]
+    pub numa: bool,
 }
 
 /// HawkActor manages the state of the HNSW database and connections to other
@@ -326,8 +329,8 @@ impl HawkActor {
             build_network_handle(args, &identities, SessionGroups::N_SESSIONS_PER_REQUEST).await?;
         let graph_store = graph.map(GraphMem::to_arc);
         let iris_store = iris_store.map(SharedIrises::to_arc);
-        let workers_handle =
-            [LEFT, RIGHT].map(|side| iris_worker::init_workers(side, iris_store[side].clone()));
+        let workers_handle = [LEFT, RIGHT]
+            .map(|side| iris_worker::init_workers(side, iris_store[side].clone(), args.numa));
 
         let bucket_statistics_left = BucketStatistics::new(
             args.match_distances_buffer_size,
@@ -625,10 +628,7 @@ impl HawkActor {
             IrisLoader {
                 party_id: self.party_id,
                 db_size: &mut self.loader_db_size,
-                irises: [
-                    self.iris_store[0].write().await,
-                    self.iris_store[1].write().await,
-                ],
+                iris_pools: self.workers_handle.clone(),
             },
             GraphLoader([
                 self.graph_store[0].write().await,
@@ -650,7 +650,17 @@ pub type Aby3SharedIrisesMut<'a> = RwLockWriteGuard<'a, Aby3SharedIrises>;
 pub struct IrisLoader<'a> {
     party_id: usize,
     db_size: &'a mut usize,
-    irises: BothEyes<Aby3SharedIrisesMut<'a>>,
+    iris_pools: BothEyes<IrisPoolHandle>,
+}
+
+impl IrisLoader<'_> {
+    pub async fn wait_completion(self) -> Result<()> {
+        try_join!(
+            self.iris_pools[LEFT].wait_completion(),
+            self.iris_pools[RIGHT].wait_completion(),
+        )?;
+        Ok(())
+    }
 }
 
 #[allow(clippy::needless_lifetimes)]
@@ -664,14 +674,14 @@ impl<'a> InMemoryStore for IrisLoader<'a> {
         right_code: &[u16],
         right_mask: &[u16],
     ) {
-        for (side, code, mask) in izip!(
-            &mut self.irises,
+        for (pool, code, mask) in izip!(
+            &self.iris_pools,
             [left_code, right_code],
             [left_mask, right_mask]
         ) {
             let iris = GaloisRingSharedIris::try_from_buffers(self.party_id, code, mask)
                 .expect("Wrong code or mask size");
-            side.insert(vector_id, iris);
+            pool.insert(vector_id, iris).unwrap();
         }
     }
 
@@ -680,8 +690,8 @@
     }
 
     fn reserve(&mut self, additional: usize) {
-        for side in &mut self.irises {
-            side.reserve(additional);
+        for side in &self.iris_pools {
+            side.reserve(additional).unwrap();
        }
     }
 
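Note on the loader change above: IrisLoader no longer writes through the in-memory store's locks; it queues inserts onto the per-eye worker channels and later calls wait_completion on both pools. A minimal, self-contained sketch of why a queued Sync-style message acts as a completion barrier on a FIFO channel (std channels stand in for the crossbeam/tokio channels used by the patch; Task, Insert and Sync are illustrative names, not the real types):

// Sketch only: FIFO task channel + ack channel as a completion barrier.
use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

enum Task {
    Insert { key: u32, value: String },
    // The reply channel lets the caller block until the worker reaches this point.
    Sync { ack: mpsc::Sender<()> },
}

fn main() {
    let (tx, rx) = mpsc::channel::<Task>();

    // Single worker thread owning the store; all mutations are serialized here.
    let worker = thread::spawn(move || {
        let mut store = HashMap::new();
        while let Ok(task) = rx.recv() {
            match task {
                Task::Insert { key, value } => {
                    store.insert(key, value);
                }
                Task::Sync { ack } => {
                    // Everything sent before this message has already been applied.
                    let _ = ack.send(());
                }
            }
        }
        store.len()
    });

    for i in 0..1000 {
        tx.send(Task::Insert { key: i, value: format!("iris-{i}") }).unwrap();
    }

    // Barrier: wait until the worker has drained every insert queued above.
    let (ack_tx, ack_rx) = mpsc::channel();
    tx.send(Task::Sync { ack: ack_tx }).unwrap();
    ack_rx.recv().unwrap();
    println!("all inserts applied");

    drop(tx); // close the channel so the worker exits
    assert_eq!(worker.join().unwrap(), 1000);
}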
@@ -692,9 +702,10 @@ impl<'a> InMemoryStore for IrisLoader<'a> {
     fn fake_db(&mut self, size: usize) {
         *self.db_size = size;
         let iris = Arc::new(GaloisRingSharedIris::default_for_party(self.party_id));
-        for side in &mut self.irises {
+        for side in &self.iris_pools {
             for i in 0..size {
-                side.insert(VectorId::from_serial_id(i as u32), iris.clone());
+                side.insert(VectorId::from_serial_id(i as u32), iris.clone())
+                    .unwrap();
             }
         }
     }
@@ -845,6 +856,70 @@ impl From for HawkRequest {
 }
 
 impl HawkRequest {
+    async fn numa_realloc(self, workers: BothEyes<IrisPoolHandle>) -> Self {
+        // TODO: Result
+        let start = Instant::now();
+
+        let (queries, queries_mirror) = join!(
+            Self::numa_realloc_orient(self.queries, &workers),
+            Self::numa_realloc_orient(self.queries_mirror, &workers)
+        );
+
+        metrics::histogram!("numa_realloc_duration").record(start.elapsed().as_secs_f64());
+        Self {
+            batch: self.batch,
+            queries,
+            queries_mirror,
+            ids: self.ids,
+        }
+    }
+
+    async fn numa_realloc_orient(
+        queries: SearchQueries,
+        workers: &BothEyes<IrisPoolHandle>,
+    ) -> SearchQueries {
+        let (left, right) = join!(
+            Self::numa_realloc_side(&queries[LEFT], &workers[LEFT]),
+            Self::numa_realloc_side(&queries[RIGHT], &workers[RIGHT])
+        );
+        Arc::new([left, right])
+    }
+
+    async fn numa_realloc_side(
+        requests: &VecRequests<VecRots<Aby3Query>>,
+        worker: &IrisPoolHandle,
+    ) -> VecRequests<VecRots<Aby3Query>> {
+        // Iterate over all the irises.
+        let all_irises_iter = requests.iter().flat_map(|rots| {
+            rots.iter()
+                .flat_map(|query| [&query.iris, &query.iris_proc])
+        });
+
+        // Go realloc the irises in parallel.
+        let tasks = all_irises_iter.map(|iris| worker.numa_realloc(iris.clone()).unwrap());
+
+        // Iterate over the results in the same order.
+        let mut new_irises_iter = try_join_all(tasks).await.unwrap().into_iter();
+
+        // Rebuild the same structure with the new irises.
+        let new_requests = requests
+            .iter()
+            .map(|rots| {
+                rots.iter()
+                    .map(|_old_query| {
+                        let iris = new_irises_iter.next().unwrap();
+                        let iris_proc = new_irises_iter.next().unwrap();
+                        Aby3Query { iris, iris_proc }
+                    })
+                    .collect_vec()
+                    .into()
+            })
+            .collect_vec();
+
+        assert!(new_irises_iter.next().is_none());
+        new_requests
+    }
+
     fn request_types(
         &self,
         iris_store: &Aby3SharedIrises,
@@ -1313,6 +1388,10 @@ impl HawkHandle {
         tracing::info!("Processing an Hawk job…");
         let now = Instant::now();
 
+        let request = request
+            .numa_realloc(hawk_actor.workers_handle.clone())
+            .await;
+
         // Deletions.
         apply_deletions(hawk_actor, &request).await?;
 
@@ -1951,6 +2030,7 @@ mod tests_db {
             hnsw_param_M: 256,
             hnsw_param_ef_search: 256,
             hnsw_prf_key: None,
+            numa: true,
             match_distances_buffer_size: 64,
             n_buckets: 10,
             disable_persistence: false,
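The numa_realloc_side helper above flattens every query iris into one ordered stream, reallocates them through the worker pool, then rebuilds the nested request structure by draining the results in the same order and asserting the stream is exhausted. A small stand-alone sketch of that flatten-process-rebuild pattern (rebuild_in_order is a hypothetical helper, not part of the patch):

// Sketch only: process nested items as a flat stream, then restore the shape.
fn rebuild_in_order<T: Clone>(
    nested: &[Vec<T>],
    process: impl Fn(&T) -> T,
) -> Vec<Vec<T>> {
    // 1. Flatten: one pass over all items in a deterministic order.
    let processed: Vec<T> = nested.iter().flatten().map(|item| process(item)).collect();

    // 2. Rebuild: walk the original structure again and pull results in order.
    let mut results = processed.into_iter();
    let rebuilt = nested
        .iter()
        .map(|inner| inner.iter().map(|_| results.next().unwrap()).collect())
        .collect();

    // Every produced item must have been consumed exactly once.
    assert!(results.next().is_none());
    rebuilt
}

fn main() {
    let batches = vec![vec![1, 2, 3], vec![], vec![4, 5]];
    let doubled = rebuild_in_order(&batches, |x| x * 2);
    assert_eq!(doubled, vec![vec![2, 4, 6], vec![], vec![8, 10]]);
    println!("{doubled:?}");
}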
diff --git a/iris-mpc-cpu/src/execution/hawk_main/iris_worker.rs b/iris-mpc-cpu/src/execution/hawk_main/iris_worker.rs
index ef68bafa5..40df53a81 100644
--- a/iris-mpc-cpu/src/execution/hawk_main/iris_worker.rs
+++ b/iris-mpc-cpu/src/execution/hawk_main/iris_worker.rs
@@ -9,6 +9,7 @@ use crate::{
 };
 use core_affinity::CoreId;
 use crossbeam::channel::{Receiver, Sender};
 use eyre::Result;
+use futures::future::try_join_all;
 use iris_mpc_common::{fast_metrics::FastHistogram, vector_id::VectorId};
 use std::{
     cmp,
@@ -23,6 +24,21 @@ use tracing::info;
 
 #[derive(Debug)]
 enum IrisTask {
+    Sync {
+        rsp: oneshot::Sender<()>,
+    },
+    /// Move an iris code to memory closer to the pool (NUMA-awareness).
+    Realloc {
+        iris: ArcIris,
+        rsp: oneshot::Sender<ArcIris>,
+    },
+    Insert {
+        vector_id: VectorId,
+        iris: ArcIris,
+    },
+    Reserve {
+        additional: usize,
+    },
     DotProductPairs {
         pairs: Vec<(ArcIris, VectorId)>,
         rsp: oneshot::Sender>>,
     },
@@ -40,12 +56,41 @@ enum IrisTask {
 
 #[derive(Clone, Debug)]
 pub struct IrisPoolHandle {
-    workers: Arc<Vec<Sender<IrisTask>>>,
+    workers: Arc<[Sender<IrisTask>]>,
     next_counter: Arc<AtomicU64>,
     metric_latency: FastHistogram,
 }
 
 impl IrisPoolHandle {
+    pub fn numa_realloc(&self, iris: ArcIris) -> Result<oneshot::Receiver<ArcIris>> {
+        let (tx, rx) = oneshot::channel();
+        let task = IrisTask::Realloc { iris, rsp: tx };
+        self.get_next_worker().send(task)?;
+        Ok(rx)
+    }
+
+    pub async fn wait_completion(&self) -> Result<()> {
+        try_join_all(self.workers.iter().map(|w| {
+            let (rsp, rx) = oneshot::channel();
+            w.send(IrisTask::Sync { rsp }).unwrap();
+            rx
+        }))
+        .await?;
+        Ok(())
+    }
+
+    pub fn insert(&self, vector_id: VectorId, iris: ArcIris) -> Result<()> {
+        let task = IrisTask::Insert { vector_id, iris };
+        self.get_mut_worker().send(task)?;
+        Ok(())
+    }
+
+    pub fn reserve(&self, additional: usize) -> Result<()> {
+        let task = IrisTask::Reserve { additional };
+        self.get_mut_worker().send(task)?;
+        Ok(())
+    }
+
     pub async fn dot_product_pairs(
         &mut self,
         pairs: Vec<(ArcIris, VectorId)>,
@@ -85,7 +130,7 @@ impl IrisPoolHandle {
     ) -> Result>> {
         let start = Instant::now();
 
-        let _ = self.get_next_worker().send(task);
+        self.get_next_worker().send(task)?;
         let res = rx.await?;
 
         self.metric_latency.record(start.elapsed().as_secs_f64());
@@ -98,9 +143,18 @@
         let idx = idx % self.workers.len();
         &self.workers[idx]
     }
+
+    /// Get the worker responsible for store mutations.
+    fn get_mut_worker(&self) -> &Sender<IrisTask> {
+        &self.workers[0]
+    }
 }
 
-pub fn init_workers(shard_index: usize, iris_store: SharedIrisesRef) -> IrisPoolHandle {
+pub fn init_workers(
+    shard_index: usize,
+    iris_store: SharedIrisesRef,
+    numa: bool,
+) -> IrisPoolHandle {
     let core_ids = select_core_ids(shard_index);
     info!(
         "Dot product shard {} running on {} cores ({:?})",
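In the handle above, get_next_worker spreads read-style tasks round-robin with an atomic counter, while the new get_mut_worker routes every mutation to worker 0 so store writes stay on a single FIFO. A self-contained sketch of that dispatch scheme using std mpsc channels and illustrative names (PoolHandle, plain String tasks); the real pool uses crossbeam senders:

// Sketch only: round-robin dispatch plus a dedicated mutation worker.
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;
use std::thread;

struct PoolHandle {
    workers: Arc<[Sender<String>]>,
    next_counter: Arc<AtomicU64>,
}

impl PoolHandle {
    /// Read-style work can go to any worker: pick them in round-robin order.
    fn get_next_worker(&self) -> &Sender<String> {
        let idx = self.next_counter.fetch_add(1, Ordering::Relaxed) as usize;
        &self.workers[idx % self.workers.len()]
    }

    /// Mutations always go to worker 0, which keeps them in a single FIFO.
    fn get_mut_worker(&self) -> &Sender<String> {
        &self.workers[0]
    }
}

fn main() {
    let mut senders = Vec::new();
    let mut joins = Vec::new();
    for id in 0..3 {
        let (tx, rx) = channel::<String>();
        senders.push(tx);
        joins.push(thread::spawn(move || {
            for msg in rx {
                println!("worker {id} got {msg}");
            }
        }));
    }

    let pool = PoolHandle {
        workers: senders.into(),
        next_counter: Arc::new(AtomicU64::new(0)),
    };

    for i in 0..6 {
        pool.get_next_worker().send(format!("query {i}")).unwrap();
    }
    pool.get_mut_worker().send("insert".to_string()).unwrap();

    drop(pool); // drop all senders so the workers exit
    for j in joins {
        j.join().unwrap();
    }
}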
@@ -116,20 +170,51 @@
         let iris_store = iris_store.clone();
         std::thread::spawn(move || {
             let _ = core_affinity::set_for_current(core_id);
-            worker_thread(rx, iris_store);
+            worker_thread(rx, iris_store, numa);
         });
     }
 
     IrisPoolHandle {
-        workers: Arc::new(channels),
+        workers: channels.into(),
         next_counter: Arc::new(AtomicU64::new(0)),
         metric_latency: FastHistogram::new("iris_worker.latency"),
     }
 }
 
-fn worker_thread(ch: Receiver<IrisTask>, iris_store: SharedIrisesRef) {
+fn worker_thread(ch: Receiver<IrisTask>, iris_store: SharedIrisesRef, numa: bool) {
     while let Ok(task) = ch.recv() {
         match task {
+            IrisTask::Realloc { iris, rsp } => {
+                // Re-allocate from this thread.
+                // This attempts to use the NUMA-aware first-touch policy of the OS.
+                let new_iris = if numa {
+                    Arc::new((*iris).clone())
+                } else {
+                    iris
+                };
+                let _ = rsp.send(new_iris);
+            }
+
+            IrisTask::Sync { rsp } => {
+                let _ = rsp.send(());
+            }
+
+            IrisTask::Insert { vector_id, iris } => {
+                let iris = if numa {
+                    Arc::new((*iris).clone())
+                } else {
+                    iris
+                };
+
+                let mut store = iris_store.data.blocking_write();
+                store.insert(vector_id, iris);
+            }
+
+            IrisTask::Reserve { additional } => {
+                let mut store = iris_store.data.blocking_write();
+                store.reserve(additional);
+            }
+
             IrisTask::DotProductPairs { pairs, rsp } => {
                 let store = iris_store.data.blocking_read();
 
@@ -167,7 +252,8 @@ fn worker_thread(ch: Receiver<IrisTask>, iris_store: SharedIrisesRef) {
 const SHARD_COUNT: usize = 2;
 
 pub fn select_core_ids(shard_index: usize) -> Vec<CoreId> {
-    let core_ids = core_affinity::get_core_ids().unwrap();
+    let mut core_ids = core_affinity::get_core_ids().unwrap();
+    core_ids.sort();
     assert!(!core_ids.is_empty());
 
     let shard_count = cmp::min(SHARD_COUNT, core_ids.len());
diff --git a/iris-mpc-cpu/src/hawkers/aby3/test_utils.rs b/iris-mpc-cpu/src/hawkers/aby3/test_utils.rs
index b5b9f622d..d92f7d543 100644
--- a/iris-mpc-cpu/src/hawkers/aby3/test_utils.rs
+++ b/iris-mpc-cpu/src/hawkers/aby3/test_utils.rs
@@ -76,7 +76,7 @@ pub async fn setup_local_aby3_players_with_preloaded_db(
         .into_iter()
         .zip(storages.into_iter())
         .map(|(session, storage)| {
-            let workers = iris_worker::init_workers(0, storage.clone());
+            let workers = iris_worker::init_workers(0, storage.clone(), true);
             Ok(Arc::new(Mutex::new(Aby3Store {
                 session,
                 storage,
@@ -93,7 +93,7 @@ pub async fn setup_local_store_aby3_players(network_t: NetworkType) -> Result
 Result<()> {
             match_distances_buffer_size: 64,
             n_buckets: 10,
             tls: None,
+            numa: true,
         };
         let args1 = HawkArgs {
             party_index: 1,
diff --git a/iris-mpc-upgrade-hawk/src/genesis/mod.rs b/iris-mpc-upgrade-hawk/src/genesis/mod.rs
index 2dd2f083d..40df7b896 100644
--- a/iris-mpc-upgrade-hawk/src/genesis/mod.rs
+++ b/iris-mpc-upgrade-hawk/src/genesis/mod.rs
@@ -888,6 +888,7 @@ async fn get_hawk_actor(config: &Config) -> Result<HawkActor> {
         match_distances_buffer_size: config.match_distances_buffer_size,
         n_buckets: config.n_buckets,
         tls: config.tls.clone(),
+        numa: config.hawk_numa,
     };
 
     log_info(format!(
@@ -1245,6 +1246,8 @@ async fn init_graph_from_stores(
         .await
         .expect("Failed to load DB");
 
+    iris_loader.wait_completion().await?;
+
     graph_loader
         .load_graph_store(&graph_store, graph_db_parallelism)
         .await?;
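The Realloc and Insert arms above clone the iris from a worker thread pinned to the shard's cores, relying on the OS first-touch policy to place the new pages on that thread's NUMA node, and select_core_ids now sorts the core list so shards map to cores in a deterministic order. A rough sketch of the first-touch idea using the same core_affinity calls the patch relies on (the buffer size and the core choice here are arbitrary):

// Sketch only: re-allocating a buffer from a pinned thread so its pages are
// first-touched (and hence placed) on that thread's NUMA node.
use std::sync::Arc;
use std::thread;

fn main() {
    // A buffer that was allocated (first-touched) by the main thread.
    let original: Arc<Vec<u16>> = Arc::new(vec![0u16; 12_800]);

    let mut core_ids = core_affinity::get_core_ids().expect("no core ids");
    core_ids.sort(); // deterministic order, as in select_core_ids
    let target_core = core_ids[core_ids.len() - 1];

    let moved = thread::spawn(move || {
        // Pin this thread; allocations it first-touches land on its NUMA node.
        let _ = core_affinity::set_for_current(target_core);

        // The clone writes every element from the pinned thread, so the new
        // allocation is local to that node; the old Arc can then be dropped.
        let local: Arc<Vec<u16>> = Arc::new((*original).clone());
        local
    })
    .join()
    .unwrap();

    assert_eq!(moved.len(), 12_800);
    println!("reallocated {} elements on the pinned core", moved.len());
}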
diff --git a/iris-mpc/src/server/mod.rs b/iris-mpc/src/server/mod.rs
index db309e37b..2c6bd51b7 100644
--- a/iris-mpc/src/server/mod.rs
+++ b/iris-mpc/src/server/mod.rs
@@ -413,6 +413,7 @@ async fn init_hawk_actor(config: &Config) -> Result<HawkActor> {
         match_distances_buffer_size: config.match_distances_buffer_size,
         n_buckets: config.n_buckets,
         tls: config.tls.clone(),
+        numa: config.hawk_numa,
     };
 
     tracing::info!(
@@ -439,6 +440,7 @@ async fn load_database(
     // TODO: not needed?
     if config.fake_db_size > 0 {
         iris_loader.fake_db(config.fake_db_size);
+        iris_loader.wait_completion().await?;
         return Ok(());
     }
 
@@ -457,14 +459,20 @@ async fn load_database(
     let store_len = iris_store.count_irises().await?;
 
     let now = Instant::now();
-    let iris_load_future = load_iris_db(
-        &mut iris_loader,
-        iris_store,
-        store_len,
-        parallelism,
-        config,
-        download_shutdown_handler,
-    );
+
+    let iris_load_future = async move {
+        load_iris_db(
+            &mut iris_loader,
+            iris_store,
+            store_len,
+            parallelism,
+            config,
+            download_shutdown_handler,
+        )
+        .await?;
+        iris_loader.wait_completion().await?;
+        eyre::Result::<()>::Ok(())
+    };
 
     let graph_load_future = graph_loader.load_graph_store(
         graph_store,
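Wrapping load_iris_db in an async move block above lets wait_completion run inside the same future, so the iris branch of the join only resolves after every queued insert has been applied, while the graph load still proceeds concurrently. A minimal sketch of that wrap-then-join pattern using the futures and eyre crates the patch already depends on (load_irises, wait_completion and load_graph are stand-ins, not the real functions):

// Sketch only: extend one branch of a try_join with a completion step.
use futures::executor::block_on;
use futures::try_join;

async fn load_irises() -> eyre::Result<usize> {
    // Stand-in for load_iris_db(...).
    Ok(1_000)
}

async fn wait_completion() -> eyre::Result<()> {
    // Stand-in for IrisLoader::wait_completion().
    Ok(())
}

async fn load_graph() -> eyre::Result<usize> {
    // Stand-in for graph_loader.load_graph_store(...).
    Ok(42)
}

fn main() -> eyre::Result<()> {
    let iris_future = async move {
        let n = load_irises().await?;
        // Barrier: all queued inserts must be applied before we report success.
        wait_completion().await?;
        eyre::Result::<usize>::Ok(n)
    };

    let (irises, graph_items) = block_on(async { try_join!(iris_future, load_graph()) })?;
    println!("loaded {irises} irises and {graph_items} graph entries");
    Ok(())
}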