diff --git a/crates/engine/primitives/src/config.rs b/crates/engine/primitives/src/config.rs
index 7d3d9c1c37f..0b77119c457 100644
--- a/crates/engine/primitives/src/config.rs
+++ b/crates/engine/primitives/src/config.rs
@@ -9,11 +9,16 @@ pub const DEFAULT_MEMORY_BLOCK_BUFFER_TARGET: u64 = 0;
 /// Default maximum concurrency for on-demand proof tasks (blinded nodes)
 pub const DEFAULT_MAX_PROOF_TASK_CONCURRENCY: u64 = 256;
 
+/// Minimum number of workers we allow configuring explicitly.
+pub const MIN_WORKER_COUNT: usize = 2;
+
 /// Returns the default number of storage worker threads based on available parallelism.
 fn default_storage_worker_count() -> usize {
     #[cfg(feature = "std")]
     {
-        std::thread::available_parallelism().map(|n| (n.get() * 2).clamp(2, 64)).unwrap_or(8)
+        std::thread::available_parallelism()
+            .map(|n| (n.get() * 2).clamp(MIN_WORKER_COUNT, 64))
+            .unwrap_or(8)
     }
     #[cfg(not(feature = "std"))]
     {
@@ -28,7 +33,7 @@ fn default_storage_worker_count() -> usize {
 /// so we use higher concurrency (1.5x storage workers) to maximize throughput and overlap.
 /// While storage workers are CPU-bound, account workers are I/O-bound coordinators.
 fn default_account_worker_count() -> usize {
-    (default_storage_worker_count() * 3) / 2
+    ((default_storage_worker_count() * 3) / 2).max(MIN_WORKER_COUNT)
 }
 
 /// The size of proof targets chunk to spawn in one multiproof calculation.
@@ -493,8 +498,8 @@ impl TreeConfig {
     }
 
     /// Setter for the number of storage proof worker threads.
-    pub const fn with_storage_worker_count(mut self, storage_worker_count: usize) -> Self {
-        self.storage_worker_count = storage_worker_count;
+    pub fn with_storage_worker_count(mut self, storage_worker_count: usize) -> Self {
+        self.storage_worker_count = storage_worker_count.max(MIN_WORKER_COUNT);
         self
     }
 
@@ -504,8 +509,8 @@ impl TreeConfig {
     }
 
     /// Setter for the number of account proof worker threads.
-    pub const fn with_account_worker_count(mut self, account_worker_count: usize) -> Self {
-        self.account_worker_count = account_worker_count;
+    pub fn with_account_worker_count(mut self, account_worker_count: usize) -> Self {
+        self.account_worker_count = account_worker_count.max(MIN_WORKER_COUNT);
         self
     }
 }
diff --git a/crates/engine/tree/src/tree/payload_processor/mod.rs b/crates/engine/tree/src/tree/payload_processor/mod.rs
index c24b0d1fe16..87ca695689d 100644
--- a/crates/engine/tree/src/tree/payload_processor/mod.rs
+++ b/crates/engine/tree/src/tree/payload_processor/mod.rs
@@ -32,7 +32,7 @@ use reth_provider::{
 use reth_revm::{db::BundleState, state::EvmState};
 use reth_trie::TrieInput;
 use reth_trie_parallel::{
-    proof_task::{ProofTaskCtx, ProofTaskManager},
+    proof_task::{spawn_proof_workers, ProofTaskCtx},
     root::ParallelStateRootError,
 };
 use reth_trie_sparse::{
@@ -167,8 +167,7 @@ where
     /// This returns a handle to await the final state root and to interact with the tasks (e.g.
     /// canceling)
     ///
-    /// Returns an error with the original transactions iterator if the proof task manager fails to
-    /// initialize.
+    /// Returns an error with the original transactions iterator if proof worker spawning fails.
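> Aside on the `config.rs` hunk above — a minimal standalone sketch (not part of the diff) of how the new defaults and the `MIN_WORKER_COUNT` floor combine, assuming a host where `available_parallelism()` reports 16 cores:

```rust
// Hypothetical illustration of the default worker-count math in config.rs.
fn main() {
    let parallelism = 16usize; // assumed result of std::thread::available_parallelism()

    // default_storage_worker_count: 2x parallelism, clamped to [MIN_WORKER_COUNT, 64]
    let storage_workers = (parallelism * 2).clamp(2, 64);

    // default_account_worker_count: 1.5x storage workers, floored at MIN_WORKER_COUNT
    let account_workers = ((storage_workers * 3) / 2).max(2);

    assert_eq!((storage_workers, account_workers), (32, 48));
    println!("storage={storage_workers}, account={account_workers}");
}
```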
     #[allow(clippy::type_complexity)]
     pub fn spawn<…>(
         &mut self,
@@ -204,14 +203,14 @@ where
         let storage_worker_count = config.storage_worker_count();
         let account_worker_count = config.account_worker_count();
         let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
-        let proof_task = match ProofTaskManager::new(
+        let proof_handle = match spawn_proof_workers(
             self.executor.handle().clone(),
             consistent_view,
             task_ctx,
             storage_worker_count,
             account_worker_count,
         ) {
-            Ok(task) => task,
+            Ok(handle) => handle,
             Err(error) => {
                 return Err((error, transactions, env, provider_builder));
             }
@@ -223,7 +222,7 @@ where
         let multi_proof_task = MultiProofTask::new(
             state_root_config,
             self.executor.clone(),
-            proof_task.handle(),
+            proof_handle.clone(),
             to_sparse_trie,
             max_multi_proof_task_concurrency,
             config.multiproof_chunking_enabled().then_some(config.multiproof_chunk_size()),
@@ -252,19 +251,7 @@ where
         let (state_root_tx, state_root_rx) = channel();
 
         // Spawn the sparse trie task using any stored trie and parallel trie configuration.
-        self.spawn_sparse_trie_task(sparse_trie_rx, proof_task.handle(), state_root_tx);
-
-        // spawn the proof task
-        self.executor.spawn_blocking(move || {
-            if let Err(err) = proof_task.run() {
-                // At least log if there is an error at any point
-                tracing::error!(
-                    target: "engine::root",
-                    ?err,
-                    "Storage proof task returned an error"
-                );
-            }
-        });
+        self.spawn_sparse_trie_task(sparse_trie_rx, proof_handle, state_root_tx);
 
         Ok(PayloadHandle {
             to_multi_proof,
diff --git a/crates/engine/tree/src/tree/payload_processor/multiproof.rs b/crates/engine/tree/src/tree/payload_processor/multiproof.rs
index f865312b83d..3d2a7a87da7 100644
--- a/crates/engine/tree/src/tree/payload_processor/multiproof.rs
+++ b/crates/engine/tree/src/tree/payload_processor/multiproof.rs
@@ -20,7 +20,7 @@ use reth_trie::{
 };
 use reth_trie_parallel::{
     proof::ParallelProof,
-    proof_task::{AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle},
+    proof_task::{AccountMultiproofInput, ProofTaskManagerHandle},
     root::ParallelStateRootError,
 };
 use std::{
@@ -346,11 +346,8 @@ pub struct MultiproofManager {
     pending: VecDeque<PendingMultiproofTask>,
     /// Executor for tasks
     executor: WorkloadExecutor,
-    /// Handle to the proof task manager used for creating `ParallelProof` instances for storage
-    /// proofs.
-    storage_proof_task_handle: ProofTaskManagerHandle,
-    /// Handle to the proof task manager used for account multiproofs.
-    account_proof_task_handle: ProofTaskManagerHandle,
+    /// Handle to the proof worker pools (storage and account).
+    proof_task_handle: ProofTaskManagerHandle,
     /// Cached storage proof roots for missed leaves; this maps
     /// hashed (missed) addresses to their storage proof roots.
     ///
@@ -372,8 +369,7 @@ impl MultiproofManager {
     fn new(
         executor: WorkloadExecutor,
         metrics: MultiProofTaskMetrics,
-        storage_proof_task_handle: ProofTaskManagerHandle,
-        account_proof_task_handle: ProofTaskManagerHandle,
+        proof_task_handle: ProofTaskManagerHandle,
         max_concurrent: usize,
     ) -> Self {
         Self {
@@ -382,8 +378,7 @@ impl MultiproofManager {
             executor,
             inflight: 0,
             metrics,
-            storage_proof_task_handle,
-            account_proof_task_handle,
+            proof_task_handle,
             missed_leaves_storage_roots: Default::default(),
         }
     }
@@ -452,7 +447,7 @@ impl MultiproofManager {
             multi_added_removed_keys,
         } = storage_multiproof_input;
 
-        let storage_proof_task_handle = self.storage_proof_task_handle.clone();
+        let storage_proof_task_handle = self.proof_task_handle.clone();
         let missed_leaves_storage_roots = self.missed_leaves_storage_roots.clone();
 
         self.executor.spawn_blocking(move || {
@@ -524,7 +519,7 @@ impl MultiproofManager {
             state_root_message_sender,
             multi_added_removed_keys,
         } = multiproof_input;
-        let account_proof_task_handle = self.account_proof_task_handle.clone();
+        let account_proof_task_handle = self.proof_task_handle.clone();
         let missed_leaves_storage_roots = self.missed_leaves_storage_roots.clone();
 
         self.executor.spawn_blocking(move || {
@@ -556,15 +551,10 @@ impl MultiproofManager {
                 missed_leaves_storage_roots,
             };
 
-            let (sender, receiver) = channel();
             let proof_result: Result<DecodedMultiProof, ParallelStateRootError> = (|| {
-                account_proof_task_handle
-                    .queue_task(ProofTaskKind::AccountMultiproof(input, sender))
-                    .map_err(|_| {
-                        ParallelStateRootError::Other(
-                            "Failed to queue account multiproof to worker pool".into(),
-                        )
-                    })?;
+                let receiver = account_proof_task_handle
+                    .queue_account_multiproof(input)
+                    .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
 
                 receiver
                     .recv()
@@ -713,8 +703,7 @@ impl MultiProofTask {
             multiproof_manager: MultiproofManager::new(
                 executor,
                 metrics.clone(),
-                proof_task_handle.clone(), // handle for storage proof workers
-                proof_task_handle,         // handle for account proof workers
+                proof_task_handle,
                 max_concurrency,
             ),
             metrics,
@@ -1223,7 +1212,7 @@ mod tests {
         DatabaseProviderFactory,
     };
     use reth_trie::{MultiProof, TrieInput};
-    use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofTaskManager};
+    use reth_trie_parallel::proof_task::{spawn_proof_workers, ProofTaskCtx};
     use revm_primitives::{B256, U256};
 
     fn create_test_state_root_task<F>(factory: F) -> MultiProofTask<F>
@@ -1238,12 +1227,12 @@ mod tests {
             config.prefix_sets.clone(),
         );
         let consistent_view = ConsistentDbView::new(factory, None);
-        let proof_task =
-            ProofTaskManager::new(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
-                .expect("Failed to create ProofTaskManager");
+        let proof_handle =
+            spawn_proof_workers(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
+                .expect("Failed to spawn proof workers");
         let channel = channel();
 
-        MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1, None)
+        MultiProofTask::new(config, executor, proof_handle, channel.0, 1, None)
     }
 
     #[test]
diff --git a/crates/engine/tree/src/tree/payload_validator.rs b/crates/engine/tree/src/tree/payload_validator.rs
index 1e63d29bf79..6c08748cdb5 100644
--- a/crates/engine/tree/src/tree/payload_validator.rs
+++ b/crates/engine/tree/src/tree/payload_validator.rs
@@ -890,13 +890,12 @@ where
                     (handle, StateRootStrategy::StateRootTask)
                 }
                 Err((error, txs, env, provider_builder)) => {
-                    // Failed to initialize proof task manager, fallback to parallel state
-                    // root
+                    // Failed to spawn proof workers, fall back to parallel state root
                     error!(
                         target: "engine::tree",
                         block=?block_num_hash,
                         ?error,
-                        "Failed to initialize proof task manager, falling back to parallel state root"
+                        "Failed to spawn proof workers, falling back to parallel state root"
                     );
                     (
                         self.payload_processor.spawn_cache_exclusive(
diff --git a/crates/trie/parallel/src/proof.rs b/crates/trie/parallel/src/proof.rs
index 7fc1f022a7e..841779aa982 100644
--- a/crates/trie/parallel/src/proof.rs
+++ b/crates/trie/parallel/src/proof.rs
@@ -1,8 +1,6 @@
 use crate::{
     metrics::ParallelTrieMetrics,
-    proof_task::{
-        AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle, StorageProofInput,
-    },
+    proof_task::{AccountMultiproofInput, ProofTaskManagerHandle, StorageProofInput},
     root::ParallelStateRootError,
     StorageRootTargets,
 };
@@ -16,10 +14,7 @@ use reth_trie::{
     DecodedMultiProof, DecodedStorageMultiProof, HashedPostStateSorted, MultiProofTargets, Nibbles,
 };
 use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
-use std::sync::{
-    mpsc::{channel, Receiver},
-    Arc,
-};
+use std::sync::{mpsc::Receiver, Arc};
 use tracing::trace;
 
 /// Parallel proof calculator.
@@ -41,7 +36,7 @@ pub struct ParallelProof {
     collect_branch_node_masks: bool,
     /// Provided by the user to give the necessary context to retain extra proofs.
     multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
-    /// Handle to the proof task manager.
+    /// Handle to the proof worker pools.
     proof_task_handle: ProofTaskManagerHandle,
     /// Cached storage proof roots for missed leaves; this maps
     /// hashed (missed) addresses to their storage proof roots.
@@ -93,7 +88,10 @@ impl ParallelProof {
         hashed_address: B256,
         prefix_set: PrefixSet,
         target_slots: B256Set,
-    ) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
+    ) -> Result<
+        Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>>,
+        ParallelStateRootError,
+    > {
         let input = StorageProofInput::new(
             hashed_address,
             prefix_set,
@@ -102,9 +100,9 @@ impl ParallelProof {
             self.multi_added_removed_keys.clone(),
         );
 
-        let (sender, receiver) = std::sync::mpsc::channel();
-        let _ = self.proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
-        receiver
+        self.proof_task_handle
+            .queue_storage_proof(input)
+            .map_err(|e| ParallelStateRootError::Other(e.to_string()))
     }
 
     /// Generate a storage multiproof according to the specified targets and hashed address.
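> Before the next hunk, a self-contained sketch (not from the diff; simplified stand-in types) of the queue-then-await pattern `queue_storage_proof` now follows. Queueing can fail immediately if the worker pool is gone, and receiving can fail later if a worker drops the result sender — the two error paths the new `Result<Receiver<…>, …>` signature makes explicit:

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Stand-ins for StorageProofInput / DecodedStorageMultiProof.
type Input = u64;
#[derive(Debug, PartialEq)]
struct Proof(u64);

/// Mirrors the handle's queue methods: send a job plus a dedicated result
/// channel, and hand the receiver back to the caller.
fn queue_proof(
    worker_tx: &Sender<(Input, Sender<Proof>)>,
    input: Input,
) -> Result<Receiver<Proof>, String> {
    let (result_tx, result_rx) = channel();
    worker_tx
        .send((input, result_tx))
        .map_err(|_| "workers unavailable".to_string())?; // pool already shut down
    Ok(result_rx)
}

fn main() {
    let (worker_tx, worker_rx) = channel::<(Input, Sender<Proof>)>();

    // A stand-in worker loop: exits when all job senders are dropped.
    thread::spawn(move || {
        while let Ok((input, result_tx)) = worker_rx.recv() {
            let _ = result_tx.send(Proof(input * 2));
        }
    });

    let receiver = queue_proof(&worker_tx, 21).expect("queueing failed");
    // recv() fails only if the worker died before replying.
    assert_eq!(receiver.recv().expect("worker hung up"), Proof(42));
}
```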
@@ -124,7 +122,7 @@ impl ParallelProof {
             "Starting storage proof generation"
         );
 
-        let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
+        let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots)?;
         let proof_result = receiver.recv().map_err(|_| {
             ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
                 format!("channel closed for {hashed_address}"),
@@ -193,15 +191,10 @@ impl ParallelProof {
             missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
         };
 
-        let (sender, receiver) = channel();
-        self.proof_task_handle
-            .queue_task(ProofTaskKind::AccountMultiproof(input, sender))
-            .map_err(|_| {
-                ParallelStateRootError::Other(
-                    "Failed to queue account multiproof: account worker pool unavailable"
-                        .to_string(),
-                )
-            })?;
+        let receiver = self
+            .proof_task_handle
+            .queue_account_multiproof(input)
+            .map_err(|e| ParallelStateRootError::Other(e.to_string()))?;
 
         // Wait for account multiproof result from worker
         let (multiproof, stats) = receiver.recv().map_err(|_| {
@@ -231,7 +224,7 @@ impl ParallelProof {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
+    use crate::proof_task::{spawn_proof_workers, ProofTaskCtx};
     use alloy_primitives::{
         keccak256,
         map::{B256Set, DefaultHashBuilder, HashMap},
@@ -313,13 +306,8 @@ mod tests {
         let task_ctx =
             ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
 
-        let proof_task =
-            ProofTaskManager::new(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();
-        let proof_task_handle = proof_task.handle();
-
-        // keep the join handle around to make sure it does not return any errors
-        // after we compute the state root
-        let join_handle = rt.spawn_blocking(move || proof_task.run());
+        let proof_task_handle =
+            spawn_proof_workers(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();
 
         let parallel_result = ParallelProof::new(
             Default::default(),
@@ -354,9 +342,7 @@ mod tests {
         // then compare the entire thing for any mask differences
         assert_eq!(parallel_result, sequential_result_decoded);
 
-        // drop the handle to terminate the task and then block on the proof task handle to make
-        // sure it does not return any errors
+        // Workers shut down automatically when the handle is dropped
         drop(proof_task_handle);
-        rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
     }
 }
diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs
index 18062747901..8545b7d2f28 100644
--- a/crates/trie/parallel/src/proof_task.rs
+++ b/crates/trie/parallel/src/proof_task.rs
@@ -1,9 +1,14 @@
-//! A Task that manages sending proof requests to a number of tasks that have longer-running
-//! database transactions.
+//! Parallel proof computation using worker pools with dedicated database transactions.
 //!
-//! The [`ProofTaskManager`] ensures that there are a max number of currently executing proof tasks,
-//! and is responsible for managing the fixed number of database transactions created at the start
-//! of the task.
+//! # Architecture
+//!
+//! - **Worker Pools**: Pre-spawned workers with dedicated database transactions
+//!   - Storage pool: Handles storage proofs and blinded storage node requests
+//!   - Account pool: Handles account multiproofs and blinded account node requests
+//! - **Direct Channel Access**: [`ProofTaskManagerHandle`] provides type-safe queue methods with
+//!   direct access to worker channels, eliminating routing overhead
+//! - **Automatic Shutdown**: Workers terminate gracefully when all handles are dropped
 //!
 //! Individual [`ProofTaskTx`] instances manage a dedicated [`InMemoryTrieCursorFactory`] and
 //! [`HashedPostStateCursorFactory`], which are each backed by a database transaction.
@@ -21,7 +26,7 @@ use alloy_rlp::{BufMut, Encodable};
 use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
 use dashmap::DashMap;
 use reth_db_api::transaction::DbTx;
-use reth_execution_errors::SparseTrieError;
+use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
 use reth_provider::{
     providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
     ProviderResult,
@@ -65,9 +70,6 @@ type AccountMultiproofResult =
     Result<(DecodedMultiProof, ParallelTrieStats), ParallelStateRootError>;
 
 /// Internal message for storage workers.
-///
-/// This is NOT exposed publicly. External callers use `ProofTaskKind::StorageProof` or
-/// `ProofTaskKind::BlindedStorageNode` which are routed through the manager's `std::mpsc` channel.
 #[derive(Debug)]
 enum StorageWorkerJob {
     /// Storage proof computation request
@@ -88,58 +90,80 @@ enum StorageWorkerJob {
     StorageProof {
         /// Storage proof input
         input: StorageProofInput,
         /// Channel to send result back
         result_sender: Sender<StorageProofResult>,
     },
     /// Blinded storage node retrieval request
     BlindedStorageNode {
         /// Target account
         account: B256,
         /// Path to the node
         path: Nibbles,
         /// Channel to send result back
         result_sender: Sender<TrieNodeProviderResult>,
     },
 }
 
-/// Manager for coordinating proof request execution across different task types.
-///
-/// # Architecture
-///
-/// This manager operates two distinct worker pools for parallel trie operations:
+/// Spawns storage and account worker pools with dedicated database transactions.
 ///
-/// **Worker Pools**:
-/// - Pre-spawned workers with dedicated long-lived transactions
-/// - **Storage pool**: Handles `StorageProof` and `BlindedStorageNode` requests
-/// - **Account pool**: Handles `AccountMultiproof` and `BlindedAccountNode` requests, delegates
-///   storage proof computation to storage pool
-/// - Tasks queued via crossbeam unbounded channels
-/// - Workers continuously process without transaction overhead
-/// - Returns error if worker pool is unavailable (all workers panicked)
+/// Returns a handle for submitting proof tasks to the worker pools.
+/// Workers run until the last handle is dropped.
 ///
-/// # Public Interface
-///
-/// The public interface through `ProofTaskManagerHandle` allows external callers to:
-/// - Submit tasks via `queue_task(ProofTaskKind)`
-/// - Use standard `std::mpsc` message passing
-/// - Receive consistent return types and error handling
-#[derive(Debug)]
-pub struct ProofTaskManager {
-    /// Sender for storage worker jobs to worker pool.
-    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
-
-    /// Number of storage workers successfully spawned.
-    ///
-    /// May be less than requested if concurrency limits reduce the worker budget.
+/// # Parameters
+/// - `executor`: Tokio runtime handle for spawning blocking tasks
+/// - `view`: Consistent database view for creating transactions
+/// - `task_ctx`: Shared context with trie updates and prefix sets
+/// - `storage_worker_count`: Number of storage workers to spawn
+/// - `account_worker_count`: Number of account workers to spawn
+pub fn spawn_proof_workers<Factory>(
+    executor: Handle,
+    view: ConsistentDbView<Factory>,
+    task_ctx: ProofTaskCtx,
     storage_worker_count: usize,
+    account_worker_count: usize,
+) -> ProviderResult<ProofTaskManagerHandle>
+where
+    Factory: DatabaseProviderFactory<Provider: BlockReader>,
+{
+    let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
+    let (account_work_tx, account_work_rx) = unbounded::<AccountWorkerJob>();
 
-    /// Sender for account worker jobs to worker pool.
-    account_work_tx: CrossbeamSender<AccountWorkerJob>,
+    tracing::info!(
+        target: "trie::proof_task",
+        storage_worker_count,
+        account_worker_count,
+        "Spawning proof worker pools"
+    );
 
-    /// Number of account workers successfully spawned.
-    account_worker_count: usize,
+    // Spawn storage workers
+    for worker_id in 0..storage_worker_count {
+        let provider_ro = view.provider_ro()?;
+        let tx = provider_ro.into_tx();
+        let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
+        let work_rx_clone = storage_work_rx.clone();
 
-    /// Receives proof task requests from [`ProofTaskManagerHandle`].
-    proof_task_rx: CrossbeamReceiver<ProofTaskMessage>,
+        executor
+            .spawn_blocking(move || storage_worker_loop(proof_task_tx, work_rx_clone, worker_id));
 
-    /// Sender for creating handles that can queue tasks.
-    proof_task_tx: CrossbeamSender<ProofTaskMessage>,
+        tracing::debug!(
+            target: "trie::proof_task",
+            worker_id,
+            "Storage worker spawned successfully"
+        );
+    }
 
-    /// The number of active handles.
-    ///
-    /// Incremented in [`ProofTaskManagerHandle::new`] and decremented in
-    /// [`ProofTaskManagerHandle::drop`].
-    active_handles: Arc<AtomicUsize>,
+    // Spawn account workers
+    for worker_id in 0..account_worker_count {
+        let provider_ro = view.provider_ro()?;
+        let tx = provider_ro.into_tx();
+        let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
+        let work_rx_clone = account_work_rx.clone();
+        let storage_work_tx_clone = storage_work_tx.clone();
 
-    /// Metrics tracking proof task operations.
-    #[cfg(feature = "metrics")]
-    metrics: ProofTaskMetrics,
+        executor.spawn_blocking(move || {
+            account_worker_loop(proof_task_tx, work_rx_clone, storage_work_tx_clone, worker_id)
+        });
+
+        tracing::debug!(
+            target: "trie::proof_task",
+            worker_id,
+            "Account worker spawned successfully"
+        );
+    }
+
+    Ok(ProofTaskManagerHandle::new(
+        storage_work_tx,
+        account_work_tx,
+        Arc::new(AtomicUsize::new(0)),
+        #[cfg(feature = "metrics")]
+        Arc::new(ProofTaskMetrics::default()),
+    ))
 }
 
 /// Worker loop for storage trie operations.
@@ -284,8 +308,6 @@ fn storage_worker_loop(
     );
 }
 
-// TODO: Refactor this with storage_worker_loop. ProofTaskManager should be removed in the following
-// pr and `MultiproofManager` should be used instead to dispatch jobs directly.
 /// Worker loop for account trie operations.
 ///
 /// # Lifecycle
@@ -657,303 +679,6 @@ fn queue_storage_proofs(
     Ok(storage_proof_receivers)
 }
 
-impl ProofTaskManager {
-    /// Creates a new [`ProofTaskManager`] with pre-spawned storage and account proof workers.
-    ///
-    /// This manager coordinates both storage and account worker pools:
-    /// - Storage workers handle `StorageProof` and `BlindedStorageNode` requests
-    /// - Account workers handle `AccountMultiproof` and `BlindedAccountNode` requests
-    ///
-    /// The `storage_worker_count` determines how many storage workers to spawn, and
-    /// `account_worker_count` determines how many account workers to spawn.
-    /// Returns an error if the underlying provider fails to create the transactions required for
-    /// spawning workers.
-    pub fn new<Factory>(
-        executor: Handle,
-        view: ConsistentDbView<Factory>,
-        task_ctx: ProofTaskCtx,
-        storage_worker_count: usize,
-        account_worker_count: usize,
-    ) -> ProviderResult<Self>
-    where
-        Factory: DatabaseProviderFactory<Provider: BlockReader>,
-    {
-        // Use unbounded channel for the router to prevent account workers from blocking
-        // when queuing storage proofs. Account workers queue many storage proofs through
-        // this channel, and blocking on a bounded channel wastes parallel worker capacity.
-        let (proof_task_tx, proof_task_rx) = unbounded();
-
-        // Use unbounded channel to ensure all storage operations are queued to workers.
-        // This maintains transaction reuse benefits and avoids fallback to on-demand execution.
-        let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
-        let (account_work_tx, account_work_rx) = unbounded::<AccountWorkerJob>();
-
-        tracing::info!(
-            target: "trie::proof_task",
-            storage_worker_count,
-            account_worker_count,
-            "Initializing storage and account worker pools with unbounded queues"
-        );
-
-        // Spawn storage workers
-        let spawned_storage_workers = Self::spawn_storage_workers(
-            &executor,
-            &view,
-            &task_ctx,
-            storage_worker_count,
-            storage_work_rx,
-        )?;
-
-        // Spawn account workers with direct access to the storage worker queue
-        let spawned_account_workers = Self::spawn_account_workers(
-            &executor,
-            &view,
-            &task_ctx,
-            account_worker_count,
-            account_work_rx,
-            storage_work_tx.clone(),
-        )?;
-
-        Ok(Self {
-            storage_work_tx,
-            storage_worker_count: spawned_storage_workers,
-            account_work_tx,
-            account_worker_count: spawned_account_workers,
-            proof_task_rx,
-            proof_task_tx,
-            active_handles: Arc::new(AtomicUsize::new(0)),
-
-            #[cfg(feature = "metrics")]
-            metrics: ProofTaskMetrics::default(),
-        })
-    }
-
-    /// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
-    pub fn handle(&self) -> ProofTaskManagerHandle {
-        ProofTaskManagerHandle::new(self.proof_task_tx.clone(), self.active_handles.clone())
-    }
-
-    /// Spawns a pool of storage workers with dedicated database transactions.
-    ///
-    /// Each worker receives `StorageWorkerJob` from the channel and processes storage proofs
-    /// and blinded storage node requests using a dedicated long-lived transaction.
-    ///
-    /// # Parameters
-    /// - `executor`: Tokio runtime handle for spawning blocking tasks
-    /// - `view`: Consistent database view for creating transactions
-    /// - `task_ctx`: Shared context with trie updates and prefix sets
-    /// - `worker_count`: Number of storage workers to spawn
-    /// - `work_rx`: Receiver for storage worker jobs
-    ///
-    /// # Returns
-    /// The number of storage workers successfully spawned
-    fn spawn_storage_workers<Factory>(
-        executor: &Handle,
-        view: &ConsistentDbView<Factory>,
-        task_ctx: &ProofTaskCtx,
-        worker_count: usize,
-        work_rx: CrossbeamReceiver<StorageWorkerJob>,
-    ) -> ProviderResult<usize>
-    where
-        Factory: DatabaseProviderFactory<Provider: BlockReader>,
-    {
-        let mut spawned_workers = 0;
-
-        for worker_id in 0..worker_count {
-            let provider_ro = view.provider_ro()?;
-            let tx = provider_ro.into_tx();
-            let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
-            let work_rx_clone = work_rx.clone();
-
-            executor.spawn_blocking(move || {
-                storage_worker_loop(proof_task_tx, work_rx_clone, worker_id)
-            });
-
-            spawned_workers += 1;
-
-            tracing::debug!(
-                target: "trie::proof_task",
-                worker_id,
-                spawned_workers,
-                "Storage worker spawned successfully"
-            );
-        }
-
-        Ok(spawned_workers)
-    }
-
-    /// Spawns a pool of account workers with dedicated database transactions.
-    ///
-    /// Each worker receives `AccountWorkerJob` from the channel and processes account multiproofs
-    /// and blinded account node requests using a dedicated long-lived transaction. Account workers
-    /// can delegate storage proof computation to the storage worker pool.
-    ///
-    /// # Parameters
-    /// - `executor`: Tokio runtime handle for spawning blocking tasks
-    /// - `view`: Consistent database view for creating transactions
-    /// - `task_ctx`: Shared context with trie updates and prefix sets
-    /// - `worker_count`: Number of account workers to spawn
-    /// - `work_rx`: Receiver for account worker jobs
-    /// - `storage_work_tx`: Sender to delegate storage proofs to storage worker pool
-    ///
-    /// # Returns
-    /// The number of account workers successfully spawned
-    fn spawn_account_workers<Factory>(
-        executor: &Handle,
-        view: &ConsistentDbView<Factory>,
-        task_ctx: &ProofTaskCtx,
-        worker_count: usize,
-        work_rx: CrossbeamReceiver<AccountWorkerJob>,
-        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
-    ) -> ProviderResult<usize>
-    where
-        Factory: DatabaseProviderFactory<Provider: BlockReader>,
-    {
-        let mut spawned_workers = 0;
-
-        for worker_id in 0..worker_count {
-            let provider_ro = view.provider_ro()?;
-            let tx = provider_ro.into_tx();
-            let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
-            let work_rx_clone = work_rx.clone();
-            let storage_work_tx_clone = storage_work_tx.clone();
-
-            executor.spawn_blocking(move || {
-                account_worker_loop(proof_task_tx, work_rx_clone, storage_work_tx_clone, worker_id)
-            });
-
-            spawned_workers += 1;
-
-            tracing::debug!(
-                target: "trie::proof_task",
-                worker_id,
-                spawned_workers,
-                "Account worker spawned successfully"
-            );
-        }
-
-        Ok(spawned_workers)
-    }
-
-    /// Loops, managing the proof tasks, routing them to the appropriate worker pools.
-    ///
-    /// # Task Routing
-    ///
-    /// - **Storage Trie Operations** (`StorageProof` and `BlindedStorageNode`): Routed to
-    ///   pre-spawned storage worker pool via unbounded channel. Returns error if workers are
-    ///   disconnected (e.g., all workers panicked).
-    /// - **Account Trie Operations** (`AccountMultiproof` and `BlindedAccountNode`): Routed to
-    ///   pre-spawned account worker pool via unbounded channel. Returns error if workers are
-    ///   disconnected.
-    ///
-    /// # Shutdown
-    ///
-    /// On termination, `storage_work_tx` and `account_work_tx` are dropped, closing the channels
-    /// and signaling all workers to shut down gracefully.
-    pub fn run(mut self) -> ProviderResult<()> {
-        loop {
-            match self.proof_task_rx.recv() {
-                Ok(message) => {
-                    match message {
-                        ProofTaskMessage::QueueTask(task) => match task {
-                            ProofTaskKind::StorageProof(input, sender) => {
-                                self.storage_work_tx
-                                    .send(StorageWorkerJob::StorageProof {
-                                        input,
-                                        result_sender: sender,
-                                    })
-                                    .expect("storage worker pool should be available");
-
-                                tracing::trace!(
-                                    target: "trie::proof_task",
-                                    "Storage proof dispatched to worker pool"
-                                );
-                            }
-
-                            ProofTaskKind::BlindedStorageNode(account, path, sender) => {
-                                #[cfg(feature = "metrics")]
-                                {
-                                    self.metrics.storage_nodes += 1;
-                                }
-
-                                self.storage_work_tx
-                                    .send(StorageWorkerJob::BlindedStorageNode {
-                                        account,
-                                        path,
-                                        result_sender: sender,
-                                    })
-                                    .expect("storage worker pool should be available");
-
-                                tracing::trace!(
-                                    target: "trie::proof_task",
-                                    ?account,
-                                    ?path,
-                                    "Blinded storage node dispatched to worker pool"
-                                );
-                            }
-
-                            ProofTaskKind::BlindedAccountNode(path, sender) => {
-                                #[cfg(feature = "metrics")]
-                                {
-                                    self.metrics.account_nodes += 1;
-                                }
-
-                                self.account_work_tx
-                                    .send(AccountWorkerJob::BlindedAccountNode {
-                                        path,
-                                        result_sender: sender,
-                                    })
-                                    .expect("account worker pool should be available");
-
-                                tracing::trace!(
-                                    target: "trie::proof_task",
-                                    ?path,
-                                    "Blinded account node dispatched to worker pool"
-                                );
-                            }
-
-                            ProofTaskKind::AccountMultiproof(input, sender) => {
-                                self.account_work_tx
-                                    .send(AccountWorkerJob::AccountMultiproof {
-                                        input,
-                                        result_sender: sender,
-                                    })
-                                    .expect("account worker pool should be available");
-
-                                tracing::trace!(
-                                    target: "trie::proof_task",
-                                    "Account multiproof dispatched to worker pool"
-                                );
-                            }
-                        },
-                        ProofTaskMessage::Terminate => {
-                            // Drop worker channels to signal workers to shut down
-                            drop(self.storage_work_tx);
-                            drop(self.account_work_tx);
-
-                            tracing::debug!(
-                                target: "trie::proof_task",
-                                storage_worker_count = self.storage_worker_count,
-                                account_worker_count = self.account_worker_count,
-                                "Shutting down proof task manager, signaling workers to terminate"
-                            );
-
-                            // Record metrics before terminating
-                            #[cfg(feature = "metrics")]
-                            self.metrics.record();
-
-                            return Ok(())
-                        }
-                    }
-                }
-                // All senders are disconnected, so we can terminate
-                // However this should never happen, as this struct stores a sender
-                Err(_) => return Ok(()),
-            };
-        }
-    }
-}
 
 /// Type alias for the factory tuple returned by `create_factories`
 type ProofFactories<'a, Tx> = (
     InMemoryTrieCursorFactory<DatabaseTrieCursorFactory<&'a Tx>, &'a TrieUpdatesSorted>,
@@ -969,8 +694,7 @@ pub struct ProofTaskTx<Tx> {
 
     /// Trie updates, prefix sets, and state updates
     task_ctx: ProofTaskCtx,
-    /// Identifier for the tx within the context of a single [`ProofTaskManager`], used only for
-    /// tracing.
+    /// Identifier for the worker within the worker pool, used only for tracing.
     id: usize,
 }
 
@@ -1135,9 +859,6 @@ struct AccountMultiproofParams<'a> {
 }
 
 /// Internal message for account workers.
-///
-/// This is NOT exposed publicly. External callers use `ProofTaskKind::AccountMultiproof` or
-/// `ProofTaskKind::BlindedAccountNode` which are routed through the manager's `std::mpsc` channel.
 #[derive(Debug)]
 enum AccountWorkerJob {
     /// Account multiproof computation request
@@ -1181,77 +902,148 @@ impl ProofTaskCtx {
     }
 }
 
-/// Message used to communicate with [`ProofTaskManager`].
-#[derive(Debug)]
-pub enum ProofTaskMessage {
-    /// A request to queue a proof task.
-    QueueTask(ProofTaskKind),
-    /// A request to terminate the proof task manager.
-    Terminate,
-}
-
-/// Proof task kind.
+/// A handle that provides type-safe access to proof worker pools.
 ///
-/// When queueing a task using [`ProofTaskMessage::QueueTask`], this enum
-/// specifies the type of proof task to be executed.
-#[derive(Debug)]
-pub enum ProofTaskKind {
-    /// A storage proof request.
-    StorageProof(StorageProofInput, Sender<StorageProofResult>),
-    /// A blinded account node request.
-    BlindedAccountNode(Nibbles, Sender<TrieNodeProviderResult>),
-    /// A blinded storage node request.
-    BlindedStorageNode(B256, Nibbles, Sender<TrieNodeProviderResult>),
-    /// An account multiproof request.
-    AccountMultiproof(AccountMultiproofInput, Sender<AccountMultiproofResult>),
-}
-
-/// A handle that wraps a single proof task sender that sends a terminate message on `Drop` if the
-/// number of active handles went to zero.
+/// The handle stores direct senders to both the storage and account worker pools,
+/// eliminating the need for a routing thread. All clones share the same channels
+/// and reference count; workers shut down gracefully once the last handle is dropped.
 #[derive(Debug)]
 pub struct ProofTaskManagerHandle {
-    /// The sender for the proof task manager.
-    sender: CrossbeamSender<ProofTaskMessage>,
-    /// The number of active handles.
+    /// Direct sender to storage worker pool
+    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
+    /// Direct sender to account worker pool
+    account_work_tx: CrossbeamSender<AccountWorkerJob>,
+    /// Active handle reference count for auto-termination
     active_handles: Arc<AtomicUsize>,
+    /// Metrics tracking (lock-free)
+    #[cfg(feature = "metrics")]
+    metrics: Arc<ProofTaskMetrics>,
 }
 
 impl ProofTaskManagerHandle {
-    /// Creates a new [`ProofTaskManagerHandle`] with the given sender.
-    pub fn new(
-        sender: CrossbeamSender<ProofTaskMessage>,
+    /// Creates a new [`ProofTaskManagerHandle`] with direct access to the worker pools.
+    ///
+    /// This constructor is internal; external callers obtain handles from
+    /// [`spawn_proof_workers`].
+    fn new(
+        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
+        account_work_tx: CrossbeamSender<AccountWorkerJob>,
         active_handles: Arc<AtomicUsize>,
+        #[cfg(feature = "metrics")] metrics: Arc<ProofTaskMetrics>,
     ) -> Self {
         active_handles.fetch_add(1, Ordering::SeqCst);
-        Self { sender, active_handles }
+        Self {
+            storage_work_tx,
+            account_work_tx,
+            active_handles,
+            #[cfg(feature = "metrics")]
+            metrics,
+        }
+    }
+
+    /// Queues a storage proof computation and returns a receiver for the result.
+    pub fn queue_storage_proof(
+        &self,
+        input: StorageProofInput,
+    ) -> Result<Receiver<StorageProofResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.storage_work_tx
+            .send(StorageWorkerJob::StorageProof { input, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("storage workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.storage_proofs.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
     }
 
-    /// Queues a task to the proof task manager.
-    pub fn queue_task(
+    /// Queues an account multiproof computation and returns a receiver for the result.
+    pub fn queue_account_multiproof(
         &self,
-        task: ProofTaskKind,
-    ) -> Result<(), crossbeam_channel::SendError<ProofTaskMessage>> {
-        self.sender.send(ProofTaskMessage::QueueTask(task))
+        input: AccountMultiproofInput,
+    ) -> Result<Receiver<AccountMultiproofResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.account_work_tx
+            .send(AccountWorkerJob::AccountMultiproof { input, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("account workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.account_proofs.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
     }
 
-    /// Terminates the proof task manager.
-    pub fn terminate(&self) {
-        let _ = self.sender.send(ProofTaskMessage::Terminate);
+    /// Internal: queues a blinded storage node request.
+    fn queue_blinded_storage_node(
+        &self,
+        account: B256,
+        path: Nibbles,
+    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.storage_work_tx
+            .send(StorageWorkerJob::BlindedStorageNode { account, path, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("storage workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.storage_nodes.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
+    }
+
+    /// Internal: queues a blinded account node request.
+    fn queue_blinded_account_node(
+        &self,
+        path: Nibbles,
+    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
+        let (tx, rx) = channel();
+        self.account_work_tx
+            .send(AccountWorkerJob::BlindedAccountNode { path, result_sender: tx })
+            .map_err(|_| {
+                ProviderError::other(std::io::Error::other("account workers unavailable"))
+            })?;
+
+        #[cfg(feature = "metrics")]
+        self.metrics.account_nodes.fetch_add(1, Ordering::Relaxed);
+
+        Ok(rx)
    }
 }
 
 impl Clone for ProofTaskManagerHandle {
     fn clone(&self) -> Self {
-        Self::new(self.sender.clone(), self.active_handles.clone())
+        Self::new(
+            self.storage_work_tx.clone(),
+            self.account_work_tx.clone(),
+            self.active_handles.clone(),
+            #[cfg(feature = "metrics")]
+            self.metrics.clone(),
+        )
     }
 }
 
 impl Drop for ProofTaskManagerHandle {
     fn drop(&mut self) {
-        // Decrement the number of active handles and terminate the manager if it was the last
-        // handle.
-        if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
-            self.terminate();
+        // Decrement the number of active handles. Once the last handle is dropped, the job
+        // senders are dropped, the worker channels disconnect, and the workers shut down.
+        let previous_handles = self.active_handles.fetch_sub(1, Ordering::SeqCst);
+
+        debug_assert_ne!(
+            previous_handles, 0,
+            "active_handles underflow in ProofTaskManagerHandle::drop (previous={})",
+            previous_handles
+        );
+
+        #[cfg(feature = "metrics")]
+        if previous_handles == 1 {
+            // Flush metrics before exit.
+            self.metrics.record();
         }
     }
 }
@@ -1261,11 +1053,11 @@ impl TrieNodeProviderFactory for ProofTaskManagerHandle {
     type StorageNodeProvider = ProofTaskTrieNodeProvider;
 
     fn account_node_provider(&self) -> Self::AccountNodeProvider {
-        ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
+        ProofTaskTrieNodeProvider::AccountNode { handle: self.clone() }
     }
 
     fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
-        ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
+        ProofTaskTrieNodeProvider::StorageNode { account, handle: self.clone() }
     }
 }
 
@@ -1274,35 +1066,34 @@ impl TrieNodeProviderFactory for ProofTaskManagerHandle {
 pub enum ProofTaskTrieNodeProvider {
     /// Blinded account trie node provider.
     AccountNode {
-        /// Sender to the proof task.
-        sender: CrossbeamSender<ProofTaskMessage>,
+        /// Handle to the proof worker pools.
+        handle: ProofTaskManagerHandle,
     },
     /// Blinded storage trie node provider.
     StorageNode {
         /// Target account.
         account: B256,
-        /// Sender to the proof task.
-        sender: CrossbeamSender<ProofTaskMessage>,
+        /// Handle to the proof worker pools.
+        handle: ProofTaskManagerHandle,
     },
 }
 
 impl TrieNodeProvider for ProofTaskTrieNodeProvider {
     fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
-        let (tx, rx) = channel();
         match self {
-            Self::AccountNode { sender } => {
-                let _ = sender.send(ProofTaskMessage::QueueTask(
-                    ProofTaskKind::BlindedAccountNode(*path, tx),
-                ));
+            Self::AccountNode { handle } => {
+                let rx = handle
+                    .queue_blinded_account_node(*path)
+                    .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
+                rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
             }
-            Self::StorageNode { sender, account } => {
-                let _ = sender.send(ProofTaskMessage::QueueTask(
-                    ProofTaskKind::BlindedStorageNode(*account, *path, tx),
-                ));
+            Self::StorageNode { handle, account } => {
+                let rx = handle
+                    .queue_blinded_storage_node(*account, *path)
+                    .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
+                rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
             }
         }
-
-        rx.recv().unwrap()
     }
 }
 
@@ -1329,9 +1120,9 @@ mod tests {
         )
     }
 
-    /// Ensures `max_concurrency` is independent of storage and account workers.
+    /// Ensures `spawn_proof_workers` returns a usable, cloneable handle.
     #[test]
-    fn proof_task_manager_independent_pools() {
+    fn spawn_proof_workers_creates_handle() {
         let runtime = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();
         runtime.block_on(async {
             let handle = tokio::runtime::Handle::current();
@@ -1338,13 +1129,13 @@ mod tests {
             let factory = create_test_provider_factory();
             let view = ConsistentDbView::new(factory, None);
             let ctx = test_ctx();
 
-            let manager = ProofTaskManager::new(handle.clone(), view, ctx, 5, 3).unwrap();
-            // With storage_worker_count=5, we get exactly 5 storage workers
-            assert_eq!(manager.storage_worker_count, 5);
-            // With account_worker_count=3, we get exactly 3 account workers
-            assert_eq!(manager.account_worker_count, 3);
+            let proof_handle = spawn_proof_workers(handle.clone(), view, ctx, 5, 3).unwrap();
+
+            // Verify handle can be cloned
+            let _cloned_handle = proof_handle.clone();
 
-            drop(manager);
+            // Workers shut down automatically when the handle is dropped
+            drop(proof_handle);
             task::yield_now().await;
         });
     }
diff --git a/crates/trie/parallel/src/proof_task_metrics.rs b/crates/trie/parallel/src/proof_task_metrics.rs
index cdb59d078d8..5f8959cea9d 100644
--- a/crates/trie/parallel/src/proof_task_metrics.rs
+++ b/crates/trie/parallel/src/proof_task_metrics.rs
@@ -1,21 +1,29 @@
 use reth_metrics::{metrics::Histogram, Metrics};
+use std::sync::{
+    atomic::{AtomicU64, Ordering},
+    Arc,
+};
 
-/// Metrics for blinded node fetching for the duration of the proof task manager.
+/// Metrics for blinded node fetching by proof workers.
 #[derive(Clone, Debug, Default)]
 pub struct ProofTaskMetrics {
     /// The actual metrics for blinded nodes.
     pub task_metrics: ProofTaskTrieMetrics,
-    /// Count of blinded account node requests.
-    pub account_nodes: usize,
-    /// Count of blinded storage node requests.
-    pub storage_nodes: usize,
+    /// Count of storage proof requests (lock-free).
+    pub storage_proofs: Arc<AtomicU64>,
+    /// Count of account proof requests (lock-free).
+    pub account_proofs: Arc<AtomicU64>,
+    /// Count of blinded account node requests (lock-free).
+    pub account_nodes: Arc<AtomicU64>,
+    /// Count of blinded storage node requests (lock-free).
+    pub storage_nodes: Arc<AtomicU64>,
 }
 
 impl ProofTaskMetrics {
     /// Record the blinded node counts into the histograms.
     pub fn record(&self) {
-        self.task_metrics.record_account_nodes(self.account_nodes);
-        self.task_metrics.record_storage_nodes(self.storage_nodes);
+        self.task_metrics
+            .record_account_nodes(self.account_nodes.load(Ordering::Relaxed) as usize);
+        self.task_metrics
+            .record_storage_nodes(self.storage_nodes.load(Ordering::Relaxed) as usize);
     }
 }
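> Taken together, a sketch of the new lifecycle (mirroring the tests in this diff; `rt` is assumed to be a tokio `Runtime`, and `create_test_provider_factory`/`ProofTaskCtx::new` are the test utilities shown above):

```rust
// Hypothetical usage fragment, based on the tests in this diff.
let factory = create_test_provider_factory();
let view = ConsistentDbView::new(factory, None);
let task_ctx = ProofTaskCtx::new(Default::default(), Default::default(), Default::default());

// One call replaces ProofTaskManager::new + manager.run(): the 5 storage and
// 3 account workers are already running when this returns.
let proof_handle = spawn_proof_workers(rt.handle().clone(), view, task_ctx, 5, 3)?;

// Handles are cheap to clone; each clone bumps the shared active-handle count.
let sparse_trie_handle = proof_handle.clone();

// Queue work through the typed methods; each returns a Receiver for that result.
// let rx = proof_handle.queue_storage_proof(input)?;
// let proof = rx.recv()?;

// No explicit terminate() call: dropping the last handle closes the job channels,
// and each worker loop exits when its recv() disconnects.
drop(sparse_trie_handle);
drop(proof_handle);
```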