13 changes: 9 additions & 4 deletions iris-mpc/src/server/mod.rs
@@ -563,7 +563,7 @@ async fn run_main_server_loop(
     // This batch can consist of N sets of iris_share + mask
     // It also includes a vector of request ids, mapping to the sets above

-    let mut batch_stream = receive_batch_stream(
+    let (mut batch_stream, sem) = receive_batch_stream(
         party_id,
         aws_clients.sqs_client.clone(),
         aws_clients.sns_client.clone(),
@@ -579,7 +579,10 @@
         batch_sync_shared_state.clone(),
     );

+    current_batch_id_atomic.fetch_add(1, Ordering::SeqCst);
     loop {
+        // Increment batch_id for the next batch, we start at 1, since the initial state for the batch sync is set to 0, which we consider to be invalid
+
         let now = Instant::now();

         let mut batch = match batch_stream.recv().await {
@@ -634,6 +637,11 @@

         task_monitor.check_tasks();

+        // we are done with the batch sync, so we can release the semaphore permit
+        // This will allow the next batch to be received
+        current_batch_id_atomic.fetch_add(1, Ordering::SeqCst);
+        sem.add_permits(1);
+
         let result_future = hawk_handle.submit_batch_query(batch.clone());

         // await the result
@@ -642,9 +650,6 @@
             .map_err(|e| eyre!("HawkActor processing timeout: {:?}", e))??;
         tx_results.send(result).await?;

-        // Increment batch_id for the next batch
-        current_batch_id_atomic.fetch_add(1, Ordering::SeqCst);
-
         shutdown_handler.increment_batches_pending_completion()
         // wrap up tracing span context
     }
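The two sides of this change form a one-permit handshake: the receiver task consumes (forgets) the single semaphore permit before fetching a batch, and the main loop only adds a permit back once the batch sync for the current batch is finished, which is what allows the batch_id increment to move from after result submission to right after batch sync. A minimal, self-contained sketch of the handshake, assuming tokio's Semaphore and mpsc; the u64 payload and the fixed loop bound are stand-ins, not the server's actual wiring:

use std::sync::Arc;
use tokio::sync::{mpsc, Semaphore};

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<u64>(1);
    let sem = Arc::new(Semaphore::new(1));

    // Receiver side: wait for a permit, then forget it so no further
    // batch can be fetched until the main loop releases a new permit.
    let recv_sem = sem.clone();
    tokio::spawn(async move {
        for batch_id in 1..=3 {
            let permit = recv_sem.acquire().await.expect("semaphore closed");
            permit.forget();
            if tx.send(batch_id).await.is_err() {
                break; // consumer dropped, stop receiving
            }
        }
    });

    // Main-loop side: finish the sync-sensitive work for this batch,
    // then release one permit so the next batch may be fetched.
    while let Some(batch_id) = rx.recv().await {
        println!("batch {batch_id}: batch sync done");
        sem.add_permits(1);
    }
}

The effect is that at most one batch sits in the sync-sensitive window at a time.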
100 changes: 64 additions & 36 deletions iris-mpc/src/services/processors/batch.rs
@@ -55,44 +55,62 @@ pub fn receive_batch_stream(
     current_batch_id_atomic: Arc<AtomicU64>,
     iris_store: Store,
     batch_sync_shared_state: Arc<tokio::sync::Mutex<BatchSyncSharedState>>,
-) -> Receiver<Result<Option<BatchQuery>, ReceiveRequestError>> {
+) -> (
+    Receiver<Result<Option<BatchQuery>, ReceiveRequestError>>,
+    Arc<Semaphore>,
+) {
     let (tx, rx) = mpsc::channel(1);
+    let sem = Arc::new(Semaphore::new(1));

-    tokio::spawn(async move {
-        loop {
-            let permit = match tx.reserve().await {
-                Ok(p) => p,
-                Err(_) => break,
-            };
+    tokio::spawn({
+        let sem = sem.clone();
+        async move {
+            loop {
+                match sem.acquire().await {
+                    // We successfully acquired the semaphore, proceed with receiving a batch
+                    // However, we forget the permit here to avoid giving it back
+                    // The main server loop will add new permits when allowed
+                    Ok(p) => p.forget(),
+                    Err(_) => {
+                        break;
+                    }
+                };
+                let permit = match tx.reserve().await {
+                    Ok(permit) => permit,
+                    Err(_) => {
+                        break;
+                    }
+                };

-            let batch = receive_batch(
-                party_id,
-                &client,
-                &sns_client,
-                &s3_client,
-                &config,
-                shares_encryption_key_pairs.clone(),
-                &shutdown_handler,
-                &uniqueness_error_result_attributes,
-                &reauth_error_result_attributes,
-                &reset_error_result_attributes,
-                current_batch_id_atomic.clone(),
-                &iris_store,
-                batch_sync_shared_state.clone(),
-            )
-            .await;
+                let batch = receive_batch(
+                    party_id,
+                    &client,
+                    &sns_client,
+                    &s3_client,
+                    &config,
+                    shares_encryption_key_pairs.clone(),
+                    &shutdown_handler,
+                    &uniqueness_error_result_attributes,
+                    &reauth_error_result_attributes,
+                    &reset_error_result_attributes,
+                    current_batch_id_atomic.clone(),
+                    &iris_store,
+                    batch_sync_shared_state.clone(),
+                )
+                .await;

-            let stop = matches!(batch, Err(_) | Ok(None));
-            permit.send(batch);
+                let stop = matches!(batch, Err(_) | Ok(None));
+                permit.send(batch);

-            if stop {
-                break;
-            }
-        }
-        tracing::info!("Stopping batch receiver.");
+                if stop {
+                    break;
+                }
+            }
+            tracing::info!("Stopping batch receiver.");
+        }
     });

-    rx
+    (rx, sem)
 }

 #[allow(clippy::too_many_arguments)]
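Between the two hunks, note the channel discipline the rewrite keeps: the slot in the bounded channel is reserved before the expensive receive_batch call, so a fetched batch is never dropped on a full channel and a dropped receiver cleanly ends the loop. A minimal sketch of that reserve-then-send pattern; expensive_fetch is a hypothetical stand-in for receive_batch:

use tokio::sync::mpsc;

async fn expensive_fetch() -> String {
    // Hypothetical stand-in for receive_batch().
    "batch".to_string()
}

async fn produce(tx: mpsc::Sender<String>) {
    loop {
        // Claim a channel slot first; Err means the receiver was
        // dropped, which doubles as the shutdown signal.
        let permit = match tx.reserve().await {
            Ok(permit) => permit,
            Err(_) => break,
        };
        let item = expensive_fetch().await;
        // Sending through a reserved permit cannot fail.
        permit.send(item);
    }
}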
@@ -198,15 +216,25 @@ impl<'a> BatchProcessor<'a> {
         let current_batch_id = self.current_batch_id_atomic.load(Ordering::SeqCst);

         // Determine the number of messages to poll based on synchronized state
-        let own_state = get_own_batch_sync_state(self.config, self.client, current_batch_id)
-            .await
-            .map_err(ReceiveRequestError::BatchSyncError)?;
+        let mut own_state =
+            get_own_batch_sync_state(self.config, self.client, current_batch_id)
+                .await
+                .map_err(ReceiveRequestError::BatchSyncError)?;

         // Update the shared state with our current state
         {
             let mut shared_state = self.batch_sync_shared_state.lock().await;
-            shared_state.batch_id = own_state.batch_id;
-            shared_state.messages_to_poll = own_state.messages_to_poll;
+            // we are here for the first time, set everything
+            if shared_state.batch_id != own_state.batch_id {
+                shared_state.batch_id = own_state.batch_id;
+                shared_state.messages_to_poll = own_state.messages_to_poll;
+            } else if shared_state.messages_to_poll == 0 {
+                // we have been here before, only update messages_to_poll if it was 0, otherwise other parties could have state mismatches
+                shared_state.messages_to_poll = own_state.messages_to_poll;
+            } else {
+                // we have already set this for this batch, so it might have gone out to other parties, so we need to update our own state to match what we already sent out
+                own_state.messages_to_poll = shared_state.messages_to_poll;
+            }
             tracing::info!(
                 "Updated shared batch sync state: batch_id={}, messages_to_poll={}",
                 shared_state.batch_id,
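The new branch encodes a reconciliation rule for state that may already have been served to the other parties: overwrite the shared view only when it is demonstrably fresh (a new batch_id) or still empty (messages_to_poll == 0); otherwise roll the local view back to match what was published. A minimal sketch of the rule with simplified stand-in types; SyncState is hypothetical, while the real BatchSyncSharedState sits behind a tokio Mutex and carries more fields:

struct SyncState {
    batch_id: u64,
    messages_to_poll: u64,
}

// Reconcile freshly computed local state with what may already have
// been handed out to the other parties for this batch.
fn reconcile(shared: &mut SyncState, own: &mut SyncState) {
    if shared.batch_id != own.batch_id {
        // First visit for this batch: publish everything.
        shared.batch_id = own.batch_id;
        shared.messages_to_poll = own.messages_to_poll;
    } else if shared.messages_to_poll == 0 {
        // Revisited, but nothing meaningful was shared yet: safe to update.
        shared.messages_to_poll = own.messages_to_poll;
    } else {
        // A non-zero count may already be visible to the other parties;
        // adopt it locally rather than diverge from what was sent out.
        own.messages_to_poll = shared.messages_to_poll;
    }
}

fn main() {
    let mut shared = SyncState { batch_id: 1, messages_to_poll: 5 };
    let mut own = SyncState { batch_id: 1, messages_to_poll: 7 };
    reconcile(&mut shared, &mut own);
    // Once a non-zero count is published for the batch, it wins locally.
    assert_eq!(own.messages_to_poll, 5);
}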