diff --git a/iris-mpc/src/server/mod.rs b/iris-mpc/src/server/mod.rs
index 4ee1b9763..7b3d8f8c9 100644
--- a/iris-mpc/src/server/mod.rs
+++ b/iris-mpc/src/server/mod.rs
@@ -563,7 +563,7 @@ async fn run_main_server_loop(
 
     // This batch can consist of N sets of iris_share + mask
     // It also includes a vector of request ids, mapping to the sets above
-    let mut batch_stream = receive_batch_stream(
+    let (mut batch_stream, sem) = receive_batch_stream(
         party_id,
         aws_clients.sqs_client.clone(),
         aws_clients.sns_client.clone(),
@@ -579,7 +579,10 @@
         batch_sync_shared_state.clone(),
     );
 
+    current_batch_id_atomic.fetch_add(1, Ordering::SeqCst);
     loop {
+        // Increment batch_id for the next batch. We start at 1, since the initial batch sync state is 0, which we consider invalid.
+
         let now = Instant::now();
 
         let mut batch = match batch_stream.recv().await {
@@ -634,6 +637,11 @@
 
         task_monitor.check_tasks();
 
+        // We are done with the batch sync, so we can release the semaphore permit.
+        // This allows the next batch to be received.
+        current_batch_id_atomic.fetch_add(1, Ordering::SeqCst);
+        sem.add_permits(1);
+
         let result_future = hawk_handle.submit_batch_query(batch.clone());
 
         // await the result
@@ -642,9 +650,6 @@
             .map_err(|e| eyre!("HawkActor processing timeout: {:?}", e))??;
         tx_results.send(result).await?;
 
-        // Increment batch_id for the next batch
-        current_batch_id_atomic.fetch_add(1, Ordering::SeqCst);
-
         shutdown_handler.increment_batches_pending_completion()
         // wrap up tracing span context
     }
diff --git a/iris-mpc/src/services/processors/batch.rs b/iris-mpc/src/services/processors/batch.rs
index f9d1b1473..5818b44b0 100644
--- a/iris-mpc/src/services/processors/batch.rs
+++ b/iris-mpc/src/services/processors/batch.rs
@@ -55,44 +55,62 @@ pub fn receive_batch_stream(
     current_batch_id_atomic: Arc<AtomicU64>,
     iris_store: Store,
     batch_sync_shared_state: Arc<Mutex<BatchSyncState>>,
-) -> Receiver<Result<Option<BatchQuery>, ReceiveRequestError>> {
+) -> (
+    Receiver<Result<Option<BatchQuery>, ReceiveRequestError>>,
+    Arc<Semaphore>,
+) {
     let (tx, rx) = mpsc::channel(1);
+    let sem = Arc::new(Semaphore::new(1));
+
+    tokio::spawn({
+        let sem = sem.clone();
+        async move {
+            loop {
+                match sem.acquire().await {
+                    // We successfully acquired the semaphore, so proceed with receiving a batch.
+                    // We forget the permit here so it is not handed back on drop;
+                    // the main server loop adds a new permit once the next batch may be received.
+                    Ok(p) => p.forget(),
+                    Err(_) => {
+                        break;
+                    }
+                };
+                let permit = match tx.reserve().await {
+                    Ok(permit) => permit,
+                    Err(_) => {
+                        break;
+                    }
+                };
+
+                let batch = receive_batch(
+                    party_id,
+                    &client,
+                    &sns_client,
+                    &s3_client,
+                    &config,
+                    shares_encryption_key_pairs.clone(),
+                    &shutdown_handler,
+                    &uniqueness_error_result_attributes,
+                    &reauth_error_result_attributes,
+                    &reset_error_result_attributes,
+                    current_batch_id_atomic.clone(),
+                    &iris_store,
+                    batch_sync_shared_state.clone(),
+                )
+                .await;
 
-    tokio::spawn(async move {
-        loop {
-            let permit = match tx.reserve().await {
-                Ok(p) => p,
-                Err(_) => break,
-            };
-
-            let batch = receive_batch(
-                party_id,
-                &client,
-                &sns_client,
-                &s3_client,
-                &config,
-                shares_encryption_key_pairs.clone(),
-                &shutdown_handler,
-                &uniqueness_error_result_attributes,
-                &reauth_error_result_attributes,
-                &reset_error_result_attributes,
-                current_batch_id_atomic.clone(),
-                &iris_store,
-                batch_sync_shared_state.clone(),
-            )
-            .await;
-
-            let stop = matches!(batch, Err(_) | Ok(None));
-            permit.send(batch);
+                let stop = matches!(batch, Err(_) | Ok(None));
+                permit.send(batch);
 
-            if stop {
-                break;
+                if stop {
+                    break;
+                }
             }
+            tracing::info!("Stopping batch receiver.");
         }
-        tracing::info!("Stopping batch receiver.");
     });
 
-    rx
+    (rx, sem)
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -198,15 +216,25 @@ impl<'a> BatchProcessor<'a> {
         let current_batch_id = self.current_batch_id_atomic.load(Ordering::SeqCst);
 
         // Determine the number of messages to poll based on synchronized state
-        let own_state = get_own_batch_sync_state(self.config, self.client, current_batch_id)
-            .await
-            .map_err(ReceiveRequestError::BatchSyncError)?;
+        let mut own_state =
+            get_own_batch_sync_state(self.config, self.client, current_batch_id)
+                .await
+                .map_err(ReceiveRequestError::BatchSyncError)?;
 
         // Update the shared state with our current state
         {
             let mut shared_state = self.batch_sync_shared_state.lock().await;
-            shared_state.batch_id = own_state.batch_id;
-            shared_state.messages_to_poll = own_state.messages_to_poll;
+            // First time we see this batch_id: set everything.
+            if shared_state.batch_id != own_state.batch_id {
+                shared_state.batch_id = own_state.batch_id;
+                shared_state.messages_to_poll = own_state.messages_to_poll;
+            } else if shared_state.messages_to_poll == 0 {
+                // We have been here before; only update messages_to_poll while it is still 0, otherwise other parties could end up with mismatched state.
+                shared_state.messages_to_poll = own_state.messages_to_poll;
+            } else {
+                // A value has already been published for this batch and other parties may have read it, so update our own state to match what we already sent out.
+                own_state.messages_to_poll = shared_state.messages_to_poll;
+            }
             tracing::info!(
                 "Updated shared batch sync state: batch_id={}, messages_to_poll={}",
                 shared_state.batch_id,
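
The gating introduced above is easier to review in isolation. Below is a minimal, self-contained sketch of the same handshake pattern, built with `tokio::sync::{Semaphore, mpsc}`: the receiver task acquires and then forgets one semaphore permit per batch, and the consumer hands a permit back only after its batch-sync step. The `u32` batch payload, loop bounds, sleep, and log lines are illustrative stand-ins, not the iris-mpc types (assumes `tokio = { version = "1", features = ["full"] }`).

```rust
use std::{sync::Arc, time::Duration};

use tokio::sync::{mpsc, Semaphore};

#[tokio::main]
async fn main() {
    // One permit up front: the receiver may pull exactly one batch ahead of the consumer.
    let (tx, mut rx) = mpsc::channel::<u32>(1);
    let sem = Arc::new(Semaphore::new(1));

    // Receiver task: acquire a permit and forget it so it is not released on drop,
    // then produce the next batch. It can only run again once the consumer
    // explicitly returns a permit via `add_permits`.
    tokio::spawn({
        let sem = sem.clone();
        async move {
            for batch_id in 1..=3u32 {
                match sem.acquire().await {
                    Ok(permit) => permit.forget(),
                    Err(_) => break, // semaphore closed
                }
                if tx.send(batch_id).await.is_err() {
                    break; // consumer gone
                }
            }
        }
    });

    // Consumer: once the "batch sync" step for the current batch is done, release one
    // permit so the receiver may start pulling the next batch during processing.
    while let Some(batch_id) = rx.recv().await {
        println!("synchronizing batch {batch_id}");
        sem.add_permits(1); // sync finished, allow the next receive
        tokio::time::sleep(Duration::from_millis(10)).await; // stand-in for processing
        println!("processed batch {batch_id}");
    }
}
```

Because the semaphore starts with a single permit and every permit the receiver takes is forgotten, at most one receive can run ahead of the consumer; the `add_permits(1)` call after the sync step is what re-opens the gate for the next batch.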
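
The three-way update of the shared sync state in the last hunk is the subtle part of this change, so here is the same reconciliation rule as a dependency-free sketch. The `BatchSync` struct, its field types, and the `reconcile` helper are hypothetical stand-ins chosen to mirror the diff, not items from the crate.

```rust
/// Shared batch-sync view that may already have been served to the other parties.
#[derive(Clone, Copy, Debug, PartialEq)]
struct BatchSync {
    batch_id: u64,
    messages_to_poll: u64,
}

/// Reconcile the freshly computed local state (`own`) with the shared state,
/// returning the local state to use afterwards.
fn reconcile(shared: &mut BatchSync, mut own: BatchSync) -> BatchSync {
    if shared.batch_id != own.batch_id {
        // First update for this batch: publish everything.
        *shared = own;
    } else if shared.messages_to_poll == 0 {
        // Seen this batch before but never published a poll count: fill it in.
        shared.messages_to_poll = own.messages_to_poll;
    } else {
        // A poll count was already published and may have been read by other parties,
        // so the local view follows it instead of overwriting it.
        own.messages_to_poll = shared.messages_to_poll;
    }
    own
}

fn main() {
    let mut shared = BatchSync { batch_id: 0, messages_to_poll: 0 };

    // Batch 1, first pass: both fields are published as-is.
    let own = reconcile(&mut shared, BatchSync { batch_id: 1, messages_to_poll: 8 });
    assert_eq!(shared, BatchSync { batch_id: 1, messages_to_poll: 8 });
    assert_eq!(own.messages_to_poll, 8);

    // Batch 1, second pass with a different local count: the already-published value wins.
    let own = reconcile(&mut shared, BatchSync { batch_id: 1, messages_to_poll: 5 });
    assert_eq!(shared.messages_to_poll, 8);
    assert_eq!(own.messages_to_poll, 8);

    println!("reconciliation rule holds: {shared:?}");
}
```

The property worth checking is the final branch: once a `messages_to_poll` value has gone out for a batch, the local state follows it rather than overwriting it, so all parties keep polling the same number of messages for that batch.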