Merged

Commits (31):
2bb29fc  au/iris-workers: Move dot products onto a worker pool (naure, Aug 20, 2025)
676482f  au/iris-workers: Refactor pairwise_distance and add metrics (naure, Aug 20, 2025)
25d6012  Merge branch 'main' into au/iris-workers-store (naure, Aug 21, 2025)
a4baabe  Merge branch 'main' into au/iris-workers (naure, Aug 21, 2025)
4410597  Merge branch 'au/iris-workers' into au/iris-workers-store (naure, Aug 21, 2025)
31a60e2  au/iris-workers: Fix benchmark (naure, Aug 21, 2025)
c354f67  Merge branch 'au/iris-workers' into au/iris-workers-store (naure, Aug 21, 2025)
7936720  au/iris-workers-store: Draft moving store mutations to pools (naure, Aug 21, 2025)
c5ad064  Merge branch 'main' into au/iris-workers (naure, Aug 21, 2025)
c6272bb  Merge branch 'au/iris-workers' into au/iris-workers-store (naure, Aug 21, 2025)
47aca51  Merge branch 'main' into au/iris-workers (naure, Aug 27, 2025)
f762993  au/iris-workers: rename config (naure, Aug 27, 2025)
e2c6ddf  au/iris-workers: support single core (naure, Aug 27, 2025)
d70671e  au/iris-workers: Remove config (naure, Aug 27, 2025)
3190c4f  Merge branch 'au/iris-workers' into au/iris-workers-store (naure, Aug 27, 2025)
3a27d36  au/iris-workers-store: NUMA-aware allocation of irises (naure, Aug 27, 2025)
7253946  Merge branch 'main' into au/iris-workers (naure, Aug 28, 2025)
ef3f710  Merge branch 'au/iris-workers' into au/iris-workers-store (naure, Aug 28, 2025)
31a8b29  Merge branch 'main' into au/iris-workers-store (naure, Aug 28, 2025)
e6b8990  au/iris-workers-store: Rename Realloc (naure, Aug 28, 2025)
31cb947  au/iris-workers-store: Optimize with Arc<[]> (naure, Aug 28, 2025)
93e0852  au/iris-workers-store: NUMA-aware allocation of new irises (naure, Aug 28, 2025)
a4d35a2  au/iris-workers-store: metric numa_realloc (naure, Aug 28, 2025)
4ae421e  Merge branch 'main' into au/iris-workers-store (naure, Sep 3, 2025)
39f0e05  Merge branch 'main' into au/iris-workers-store (naure, Sep 8, 2025)
0d21bfb  Merge branch 'main' into au/iris-workers-store (naure, Sep 11, 2025)
0793186  au/iris-workers-store: Mut worker. Wait for completion of all inserts. (naure, Sep 11, 2025)
e552861  au/iris-workers-store: Remove unused return channel (naure, Sep 11, 2025)
a661bc4  au/iris-workers-store: HAWK_NUMA config (naure, Sep 11, 2025)
1cbc75d  Merge branch 'main' into au/iris-workers-store (naure, Sep 11, 2025)
3755394  au/iris-workers-store: Fix after merge (naure, Sep 11, 2025)
8 changes: 8 additions & 0 deletions iris-mpc-common/src/config/mod.rs
@@ -229,6 +229,9 @@ pub struct Config {
#[serde(default)]
pub hawk_prf_key: Option<u64>,

#[serde(default = "default_hawk_numa")]
pub hawk_numa: bool,

#[serde(default = "default_max_deletions_per_batch")]
pub max_deletions_per_batch: usize,

@@ -399,6 +402,10 @@ fn default_hnsw_param_ef_search() -> usize {
256
}

fn default_hawk_numa() -> bool {
true
}

fn default_service_ports() -> Vec<String> {
vec!["4000".to_string(); 3]
}
@@ -762,6 +769,7 @@ impl From<Config> for CommonConfig {
hnsw_param_M,
hnsw_param_ef_search,
hawk_prf_key,
hawk_numa: _, // could be different for each server
max_deletions_per_batch,
enable_modifications_sync,
enable_modifications_replay,
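Note: the new `hawk_numa` flag uses the same serde-default pattern as the neighboring fields, and it is deliberately dropped from `CommonConfig` (`hawk_numa: _`) so each server can set it independently. A minimal, self-contained sketch of that pattern; `MiniConfig` below is illustrative, not the real `Config`:

```rust
// Sketch only: requires the `serde` (with derive) and `serde_json` crates.
use serde::Deserialize;

#[derive(Deserialize)]
struct MiniConfig {
    // Same pattern as `hawk_numa` above: a missing field falls back to the default fn.
    #[serde(default = "default_hawk_numa")]
    hawk_numa: bool,
}

fn default_hawk_numa() -> bool {
    true
}

fn main() {
    // An empty object leaves the field at its default (true).
    let cfg: MiniConfig = serde_json::from_str("{}").unwrap();
    assert!(cfg.hawk_numa);

    // An explicit value overrides the default.
    let cfg: MiniConfig = serde_json::from_str(r#"{"hawk_numa": false}"#).unwrap();
    assert!(!cfg.hawk_numa);
}
```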
110 changes: 95 additions & 15 deletions iris-mpc-cpu/src/execution/hawk_main.rs
@@ -23,7 +23,7 @@ use crate::{
};
use clap::Parser;
use eyre::{eyre, Report, Result};
use futures::try_join;
use futures::{future::try_join_all, try_join};
use intra_batch::intra_batch_is_match;
use iris_mpc_common::job::Eye;
use iris_mpc_common::{
@@ -125,6 +125,9 @@ pub struct HawkArgs {

#[clap(flatten)]
pub tls: Option<TlsConfig>,

#[clap(long, default_value_t = false)]
pub numa: bool,
}

/// HawkActor manages the state of the HNSW database and connections to other
@@ -326,8 +329,8 @@ impl HawkActor {
build_network_handle(args, &identities, SessionGroups::N_SESSIONS_PER_REQUEST).await?;
let graph_store = graph.map(GraphMem::to_arc);
let iris_store = iris_store.map(SharedIrises::to_arc);
let workers_handle =
[LEFT, RIGHT].map(|side| iris_worker::init_workers(side, iris_store[side].clone()));
let workers_handle = [LEFT, RIGHT]
.map(|side| iris_worker::init_workers(side, iris_store[side].clone(), args.numa));

let bucket_statistics_left = BucketStatistics::new(
args.match_distances_buffer_size,
@@ -625,10 +628,7 @@ impl HawkActor {
IrisLoader {
party_id: self.party_id,
db_size: &mut self.loader_db_size,
irises: [
self.iris_store[0].write().await,
self.iris_store[1].write().await,
],
iris_pools: self.workers_handle.clone(),
},
GraphLoader([
self.graph_store[0].write().await,
@@ -650,7 +650,17 @@ pub type Aby3SharedIrisesMut<'a> = RwLockWriteGuard<'a, Aby3SharedIrises>;
pub struct IrisLoader<'a> {
party_id: usize,
db_size: &'a mut usize,
irises: BothEyes<Aby3SharedIrisesMut<'a>>,
iris_pools: BothEyes<IrisPoolHandle>,
}

impl IrisLoader<'_> {
pub async fn wait_completion(self) -> Result<()> {
try_join!(
self.iris_pools[LEFT].wait_completion(),
self.iris_pools[RIGHT].wait_completion(),
)?;
Ok(())
}
}

#[allow(clippy::needless_lifetimes)]
@@ -664,14 +674,14 @@ impl<'a> InMemoryStore for IrisLoader<'a> {
right_code: &[u16],
right_mask: &[u16],
) {
for (side, code, mask) in izip!(
&mut self.irises,
for (pool, code, mask) in izip!(
&self.iris_pools,
[left_code, right_code],
[left_mask, right_mask]
) {
let iris = GaloisRingSharedIris::try_from_buffers(self.party_id, code, mask)
.expect("Wrong code or mask size");
side.insert(vector_id, iris);
pool.insert(vector_id, iris).unwrap();
}
}

@@ -680,8 +690,8 @@ impl<'a> InMemoryStore for IrisLoader<'a> {
}

fn reserve(&mut self, additional: usize) {
for side in &mut self.irises {
side.reserve(additional);
for side in &self.iris_pools {
side.reserve(additional).unwrap();
}
}

@@ -692,9 +702,10 @@
fn fake_db(&mut self, size: usize) {
*self.db_size = size;
let iris = Arc::new(GaloisRingSharedIris::default_for_party(self.party_id));
for side in &mut self.irises {
for side in &self.iris_pools {
for i in 0..size {
side.insert(VectorId::from_serial_id(i as u32), iris.clone());
side.insert(VectorId::from_serial_id(i as u32), iris.clone())
.unwrap();
}
}
}
@@ -845,6 +856,70 @@ impl From<BatchQuery> for HawkRequest {
}

impl HawkRequest {
async fn numa_realloc(self, workers: BothEyes<IrisPoolHandle>) -> Self {
// TODO: Result<Self>
let start = Instant::now();

let (queries, queries_mirror) = join!(
Self::numa_realloc_orient(self.queries, &workers),
Self::numa_realloc_orient(self.queries_mirror, &workers)
);

metrics::histogram!("numa_realloc_duration").record(start.elapsed().as_secs_f64());
Self {
batch: self.batch,
queries,
queries_mirror,
ids: self.ids,
}
}

async fn numa_realloc_orient(
queries: SearchQueries,
workers: &BothEyes<IrisPoolHandle>,
) -> SearchQueries {
let (left, right) = join!(
Self::numa_realloc_side(&queries[LEFT], &workers[LEFT]),
Self::numa_realloc_side(&queries[RIGHT], &workers[RIGHT])
);
Arc::new([left, right])
}

async fn numa_realloc_side(
requests: &VecRequests<VecRots<Aby3Query>>,
worker: &IrisPoolHandle,
) -> VecRequests<VecRots<Aby3Query>> {
// Iterate over all the irises.
let all_irises_iter = requests.iter().flat_map(|rots| {
rots.iter()
.flat_map(|query| [&query.iris, &query.iris_proc])
});

// Go realloc the irises in parallel.
let tasks = all_irises_iter.map(|iris| worker.numa_realloc(iris.clone()).unwrap());

// Iterate over the results in the same order.
let mut new_irises_iter = try_join_all(tasks).await.unwrap().into_iter();

// Rebuild the same structure with the new irises.
let new_requests = requests
.iter()
.map(|rots| {
rots.iter()
.map(|_old_query| {
let iris = new_irises_iter.next().unwrap();
let iris_proc = new_irises_iter.next().unwrap();
Aby3Query { iris, iris_proc }
})
.collect_vec()
.into()
})
.collect_vec();

assert!(new_irises_iter.next().is_none());
new_requests
}

fn request_types(
&self,
iris_store: &Aby3SharedIrises,
@@ -1313,6 +1388,10 @@ impl HawkHandle {
tracing::info!("Processing a Hawk job…");
let now = Instant::now();

let request = request
.numa_realloc(hawk_actor.workers_handle.clone())
.await;

// Deletions.
apply_deletions(hawk_actor, &request).await?;

@@ -1951,6 +2030,7 @@ mod tests_db {
hnsw_param_M: 256,
hnsw_param_ef_search: 256,
hnsw_prf_key: None,
numa: true,
match_distances_buffer_size: 64,
n_buckets: 10,
disable_persistence: false,
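Note: `numa_realloc_side` flattens every request into a flat stream of irises (two per query: `iris` and `iris_proc`), fans the reallocations out to the pool, then rebuilds the nested `VecRequests<VecRots<Aby3Query>>` by draining the results in the original order; the trailing `assert!` checks that the counts line up. A stripped-down sketch of that flatten, process, rebuild-in-order pattern, with plain integers standing in for irises:

```rust
// Sketch only: models the ordering contract of numa_realloc_side.
fn main() {
    // Two requests; each inner pair stands for (iris, iris_proc).
    let requests: Vec<Vec<(u32, u32)>> = vec![vec![(1, 2), (3, 4)], vec![(5, 6)]];

    // Flatten: emit two entries per query, in a fixed order.
    let flat: Vec<u32> = requests
        .iter()
        .flat_map(|rots| rots.iter().flat_map(|&(a, b)| [a, b]))
        .collect();

    // Stand-in for the parallel realloc on the worker pool.
    let processed: Vec<u32> = flat.into_iter().map(|x| x * 10).collect();

    // Rebuild: consume results in exactly the order the tasks were emitted.
    let mut it = processed.into_iter();
    let rebuilt: Vec<Vec<(u32, u32)>> = requests
        .iter()
        .map(|rots| {
            rots.iter()
                .map(|_| (it.next().unwrap(), it.next().unwrap()))
                .collect()
        })
        .collect();

    // Mirrors the assert in numa_realloc_side: nothing left over.
    assert!(it.next().is_none());
    assert_eq!(rebuilt, vec![vec![(10, 20), (30, 40)], vec![(50, 60)]]);
}
```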
100 changes: 93 additions & 7 deletions iris-mpc-cpu/src/execution/hawk_main/iris_worker.rs
@@ -9,6 +9,7 @@ use crate::{
use core_affinity::CoreId;
use crossbeam::channel::{Receiver, Sender};
use eyre::Result;
use futures::future::try_join_all;
use iris_mpc_common::{fast_metrics::FastHistogram, vector_id::VectorId};
use std::{
cmp,
@@ -23,6 +24,21 @@ use tracing::info;

#[derive(Debug)]
enum IrisTask {
Sync {
rsp: oneshot::Sender<()>,
},
/// Move an iris code to memory closer to the pool (NUMA-awareness).
Realloc {
iris: ArcIris,
rsp: oneshot::Sender<ArcIris>,
},
Insert {
vector_id: VectorId,
iris: ArcIris,
},
Reserve {
additional: usize,
},
DotProductPairs {
pairs: Vec<(ArcIris, VectorId)>,
rsp: oneshot::Sender<Vec<RingElement<u16>>>,
@@ -40,12 +56,41 @@

#[derive(Clone, Debug)]
pub struct IrisPoolHandle {
workers: Arc<Vec<Sender<IrisTask>>>,
workers: Arc<[Sender<IrisTask>]>,
next_counter: Arc<AtomicU64>,
metric_latency: FastHistogram,
}

impl IrisPoolHandle {
pub fn numa_realloc(&self, iris: ArcIris) -> Result<oneshot::Receiver<ArcIris>> {
let (tx, rx) = oneshot::channel();
let task = IrisTask::Realloc { iris, rsp: tx };
self.get_next_worker().send(task)?;
Ok(rx)
}

pub async fn wait_completion(&self) -> Result<()> {
try_join_all(self.workers.iter().map(|w| {
let (rsp, rx) = oneshot::channel();
w.send(IrisTask::Sync { rsp }).unwrap();
rx
}))
.await?;
Ok(())
}

pub fn insert(&self, vector_id: VectorId, iris: ArcIris) -> Result<()> {
let task = IrisTask::Insert { vector_id, iris };
self.get_mut_worker().send(task)?;
Ok(())
}

pub fn reserve(&self, additional: usize) -> Result<()> {
let task = IrisTask::Reserve { additional };
self.get_mut_worker().send(task)?;
Ok(())
}

pub async fn dot_product_pairs(
&mut self,
pairs: Vec<(ArcIris, VectorId)>,
@@ -85,7 +130,7 @@
) -> Result<Vec<RingElement<u16>>> {
let start = Instant::now();

let _ = self.get_next_worker().send(task);
self.get_next_worker().send(task)?;
let res = rx.await?;

self.metric_latency.record(start.elapsed().as_secs_f64());
@@ -98,9 +143,18 @@
let idx = idx % self.workers.len();
&self.workers[idx]
}

/// Get the worker responsible for store mutations.
fn get_mut_worker(&self) -> &Sender<IrisTask> {
&self.workers[0]
}
}

pub fn init_workers(shard_index: usize, iris_store: SharedIrisesRef<ArcIris>) -> IrisPoolHandle {
pub fn init_workers(
shard_index: usize,
iris_store: SharedIrisesRef<ArcIris>,
numa: bool,
) -> IrisPoolHandle {
let core_ids = select_core_ids(shard_index);
info!(
"Dot product shard {} running on {} cores ({:?})",
@@ -116,20 +170,51 @@
let iris_store = iris_store.clone();
std::thread::spawn(move || {
let _ = core_affinity::set_for_current(core_id);
worker_thread(rx, iris_store);
worker_thread(rx, iris_store, numa);
});
}

IrisPoolHandle {
workers: Arc::new(channels),
workers: channels.into(),
next_counter: Arc::new(AtomicU64::new(0)),
metric_latency: FastHistogram::new("iris_worker.latency"),
}
}

fn worker_thread(ch: Receiver<IrisTask>, iris_store: SharedIrisesRef<ArcIris>) {
fn worker_thread(ch: Receiver<IrisTask>, iris_store: SharedIrisesRef<ArcIris>, numa: bool) {
while let Ok(task) = ch.recv() {
match task {
IrisTask::Realloc { iris, rsp } => {
// Re-allocate from this thread.
// This attempts to use the NUMA-aware first-touch policy of the OS.
let new_iris = if numa {
Arc::new((*iris).clone())
} else {
iris
};
let _ = rsp.send(new_iris);
}

IrisTask::Sync { rsp } => {
let _ = rsp.send(());
}

IrisTask::Insert { vector_id, iris } => {
let iris = if numa {
Arc::new((*iris).clone())
} else {
iris
};

let mut store = iris_store.data.blocking_write();
store.insert(vector_id, iris);
}

IrisTask::Reserve { additional } => {
let mut store = iris_store.data.blocking_write();
store.reserve(additional);
}

IrisTask::DotProductPairs { pairs, rsp } => {
let store = iris_store.data.blocking_read();

@@ -167,7 +252,8 @@ fn worker_thread(ch: Receiver<IrisTask>, iris_store: SharedIrisesRef<ArcIris>) {
const SHARD_COUNT: usize = 2;

pub fn select_core_ids(shard_index: usize) -> Vec<CoreId> {
let core_ids = core_affinity::get_core_ids().unwrap();
let mut core_ids = core_affinity::get_core_ids().unwrap();
core_ids.sort();
assert!(!core_ids.is_empty());

let shard_count = cmp::min(SHARD_COUNT, core_ids.len());
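Note: the `Realloc` and `Insert` handlers rely on the OS first-touch policy: each worker thread is pinned with `core_affinity`, so cloning an `Arc`'d iris from that thread places the fresh pages on the worker's own NUMA node. Worker 0 is reserved for store mutations (`get_mut_worker`), which keeps inserts and reserves ordered on a single thread. A stripped-down sketch of the pinned-worker clone round-trip; std channels stand in for the crossbeam/oneshot channels here, and the pinning itself is omitted:

```rust
use std::sync::{mpsc, Arc};
use std::thread;

// Stand-in for ArcIris: a shared, immutable buffer.
type ArcBuf = Arc<Vec<u16>>;

fn main() {
    // Each request carries the shared buffer plus a one-shot reply channel.
    let (tx, rx) = mpsc::channel::<(ArcBuf, mpsc::Sender<ArcBuf>)>();

    // Worker thread; the real code pins it with core_affinity::set_for_current.
    thread::spawn(move || {
        while let Ok((iris, rsp)) = rx.recv() {
            // Re-allocate from this thread: under first-touch, the clone's
            // pages land on the NUMA node this thread is running on.
            let local = Arc::new((*iris).clone());
            let _ = rsp.send(local);
        }
    });

    let iris: ArcBuf = Arc::new(vec![0u16; 12_800]);
    let (rsp_tx, rsp_rx) = mpsc::channel();
    tx.send((iris, rsp_tx)).unwrap();
    let relocated = rsp_rx.recv().unwrap();
    assert_eq!(relocated.len(), 12_800);
}
```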