24 changes: 20 additions & 4 deletions crates/engine/primitives/src/config.rs
@@ -9,11 +9,27 @@ pub const DEFAULT_MEMORY_BLOCK_BUFFER_TARGET: u64 = 0;
/// Default maximum concurrency for on-demand proof tasks (blinded nodes)
pub const DEFAULT_MAX_PROOF_TASK_CONCURRENCY: u64 = 256;

/// Minimum number of workers we allow configuring explicitly.
pub const MIN_WORKER_COUNT: usize = 2;

/// Clamps the worker count to the minimum allowed value.
///
/// Ensures that the worker count is at least [`MIN_WORKER_COUNT`].
const fn clamp_worker_count(count: usize) -> usize {
Collaborator:
is this just .max(MIN_WORKER_COUNT)? Let's move this to with_*_worker_count functions, no need to have a separate helper fn for this

if count >= MIN_WORKER_COUNT {
count
} else {
MIN_WORKER_COUNT
}
}
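
A rough sketch of the reviewer's suggestion above, using a simplified stand-in for TreeConfig rather than the actual reth type: the clamp is inlined in the setters via .max(MIN_WORKER_COUNT). One caveat: the real setters are const fn, and calling .max() on integers in const context is, to my knowledge, not allowed on stable Rust, which may be why an explicit branch helper was used instead.

const MIN_WORKER_COUNT: usize = 2;

#[derive(Default)]
struct TreeConfig {
    storage_worker_count: usize,
    account_worker_count: usize,
}

impl TreeConfig {
    // Non-const here; the real setters are `const fn`, where `.max()` would
    // need the explicit branch shown in the diff above.
    fn with_storage_worker_count(mut self, count: usize) -> Self {
        self.storage_worker_count = count.max(MIN_WORKER_COUNT);
        self
    }

    fn with_account_worker_count(mut self, count: usize) -> Self {
        self.account_worker_count = count.max(MIN_WORKER_COUNT);
        self
    }
}

fn main() {
    let config = TreeConfig::default()
        .with_storage_worker_count(1) // clamped up to MIN_WORKER_COUNT
        .with_account_worker_count(16);
    assert_eq!(config.storage_worker_count, 2);
    assert_eq!(config.account_worker_count, 16);
}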

/// Returns the default number of storage worker threads based on available parallelism.
fn default_storage_worker_count() -> usize {
#[cfg(feature = "std")]
{
std::thread::available_parallelism().map(|n| (n.get() * 2).clamp(2, 64)).unwrap_or(8)
std::thread::available_parallelism()
.map(|n| (n.get() * 2).clamp(MIN_WORKER_COUNT, 64))
.unwrap_or(8)
}
#[cfg(not(feature = "std"))]
{
@@ -28,7 +44,7 @@ fn default_storage_worker_count() -> usize {
/// so we use higher concurrency (1.5x storage workers) to maximize throughput and overlap.
/// While storage workers are CPU-bound, account workers are I/O-bound coordinators.
fn default_account_worker_count() -> usize {
(default_storage_worker_count() * 3) / 2
((default_storage_worker_count() * 3) / 2).max(MIN_WORKER_COUNT)
}

/// The size of proof targets chunk to spawn in one multiproof calculation.
@@ -494,7 +510,7 @@ impl TreeConfig {

/// Setter for the number of storage proof worker threads.
pub const fn with_storage_worker_count(mut self, storage_worker_count: usize) -> Self {
self.storage_worker_count = storage_worker_count;
self.storage_worker_count = clamp_worker_count(storage_worker_count);
self
}

@@ -505,7 +521,7 @@ impl TreeConfig {

/// Setter for the number of account proof worker threads.
pub const fn with_account_worker_count(mut self, account_worker_count: usize) -> Self {
self.account_worker_count = account_worker_count;
self.account_worker_count = clamp_worker_count(account_worker_count);
self
}
}
25 changes: 6 additions & 19 deletions crates/engine/tree/src/tree/payload_processor/mod.rs
@@ -32,7 +32,7 @@ use reth_provider::{
use reth_revm::{db::BundleState, state::EvmState};
use reth_trie::TrieInput;
use reth_trie_parallel::{
proof_task::{ProofTaskCtx, ProofTaskManager},
proof_task::{spawn_proof_workers, ProofTaskCtx},
root::ParallelStateRootError,
};
use reth_trie_sparse::{
@@ -167,8 +167,7 @@ where
/// This returns a handle to await the final state root and to interact with the tasks (e.g.
/// canceling)
///
/// Returns an error with the original transactions iterator if the proof task manager fails to
/// initialize.
/// Returns an error with the original transactions iterator if proof worker spawning fails.
#[allow(clippy::type_complexity)]
pub fn spawn<P, I: ExecutableTxIterator<Evm>>(
&mut self,
@@ -204,14 +203,14 @@ where
let storage_worker_count = config.storage_worker_count();
let account_worker_count = config.account_worker_count();
let max_proof_task_concurrency = config.max_proof_task_concurrency() as usize;
let proof_task = match ProofTaskManager::new(
let proof_handle = match spawn_proof_workers(
self.executor.handle().clone(),
consistent_view,
task_ctx,
storage_worker_count,
account_worker_count,
) {
Ok(task) => task,
Ok(handle) => handle,
Err(error) => {
return Err((error, transactions, env, provider_builder));
}
@@ -223,7 +222,7 @@
let multi_proof_task = MultiProofTask::new(
state_root_config,
self.executor.clone(),
proof_task.handle(),
proof_handle.clone(),
to_sparse_trie,
max_multi_proof_task_concurrency,
config.multiproof_chunking_enabled().then_some(config.multiproof_chunk_size()),
@@ -252,19 +251,7 @@ where
let (state_root_tx, state_root_rx) = channel();

// Spawn the sparse trie task using any stored trie and parallel trie configuration.
self.spawn_sparse_trie_task(sparse_trie_rx, proof_task.handle(), state_root_tx);

// spawn the proof task
self.executor.spawn_blocking(move || {
if let Err(err) = proof_task.run() {
// At least log if there is an error at any point
tracing::error!(
target: "engine::root",
?err,
"Storage proof task returned an error"
);
}
});
self.spawn_sparse_trie_task(sparse_trie_rx, proof_handle, state_root_tx);

Ok(PayloadHandle {
to_multi_proof,
28 changes: 11 additions & 17 deletions crates/engine/tree/src/tree/payload_processor/multiproof.rs
@@ -20,7 +20,7 @@ use reth_trie::{
};
use reth_trie_parallel::{
proof::ParallelProof,
proof_task::{AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle},
proof_task::{AccountMultiproofInput, ProofTaskManagerHandle},
root::ParallelStateRootError,
};
use std::{
@@ -346,10 +346,9 @@ pub struct MultiproofManager {
pending: VecDeque<PendingMultiproofTask>,
/// Executor for tasks
executor: WorkloadExecutor,
/// Handle to the proof task manager used for creating `ParallelProof` instances for storage
/// proofs.
/// Handle to the proof worker pool for storage proofs.
storage_proof_task_handle: ProofTaskManagerHandle,
/// Handle to the proof task manager used for account multiproofs.
/// Handle to the proof worker pool for account multiproofs.
account_proof_task_handle: ProofTaskManagerHandle,
Collaborator:
any reason why do we need to separate them here, if it's always the same proof_task_handle passed to both fields?

Member Author:
thanks for the note

Member Author:
thanks for the note, addressed them
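
For illustration, a hypothetical sketch of the reviewer's point: since both fields are populated from clones of the same handle, the manager could hold a single cloneable handle and use it for both storage proofs and account multiproofs. The names below are illustrative stand-ins, not the actual reth types.

#[derive(Clone)]
struct ProofWorkerHandle;

impl ProofWorkerHandle {
    fn queue_storage_proof(&self) {
        // would enqueue work to the storage proof worker pool
    }
    fn queue_account_multiproof(&self) {
        // would enqueue work to the account multiproof worker pool
    }
}

struct MultiproofManager {
    // One shared handle instead of separate storage/account fields.
    proof_task_handle: ProofWorkerHandle,
}

impl MultiproofManager {
    fn new(proof_task_handle: ProofWorkerHandle) -> Self {
        Self { proof_task_handle }
    }

    fn spawn_storage_proof(&self) {
        self.proof_task_handle.queue_storage_proof();
    }

    fn spawn_account_multiproof(&self) {
        self.proof_task_handle.queue_account_multiproof();
    }
}

fn main() {
    let manager = MultiproofManager::new(ProofWorkerHandle);
    manager.spawn_storage_proof();
    manager.spawn_account_multiproof();
}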

/// Cached storage proof roots for missed leaves; this maps
/// hashed (missed) addresses to their storage proof roots.
@@ -556,15 +555,10 @@ impl MultiproofManager {
missed_leaves_storage_roots,
};

let (sender, receiver) = channel();
let proof_result: Result<DecodedMultiProof, ParallelStateRootError> = (|| {
account_proof_task_handle
.queue_task(ProofTaskKind::AccountMultiproof(input, sender))
.map_err(|_| {
ParallelStateRootError::Other(
"Failed to queue account multiproof to worker pool".into(),
)
})?;
let receiver = account_proof_task_handle
.queue_account_multiproof(input)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

receiver
.recv()
@@ -1223,7 +1217,7 @@ mod tests {
DatabaseProviderFactory,
};
use reth_trie::{MultiProof, TrieInput};
use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofTaskManager};
use reth_trie_parallel::proof_task::{spawn_proof_workers, ProofTaskCtx};
use revm_primitives::{B256, U256};

fn create_test_state_root_task<F>(factory: F) -> MultiProofTask
@@ -1238,12 +1232,12 @@
config.prefix_sets.clone(),
);
let consistent_view = ConsistentDbView::new(factory, None);
let proof_task =
ProofTaskManager::new(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
.expect("Failed to create ProofTaskManager");
let proof_handle =
spawn_proof_workers(executor.handle().clone(), consistent_view, task_ctx, 1, 1)
.expect("Failed to spawn proof workers");
let channel = channel();

MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1, None)
MultiProofTask::new(config, executor, proof_handle, channel.0, 1, None)
}

#[test]
5 changes: 2 additions & 3 deletions crates/engine/tree/src/tree/payload_validator.rs
@@ -890,13 +890,12 @@ where
(handle, StateRootStrategy::StateRootTask)
}
Err((error, txs, env, provider_builder)) => {
// Failed to initialize proof task manager, fallback to parallel state
// root
// Failed to spawn proof workers, fallback to parallel state root
error!(
target: "engine::tree",
block=?block_num_hash,
?error,
"Failed to initialize proof task manager, falling back to parallel state root"
"Failed to spawn proof workers, falling back to parallel state root"
);
(
self.payload_processor.spawn_cache_exclusive(
52 changes: 19 additions & 33 deletions crates/trie/parallel/src/proof.rs
@@ -1,8 +1,6 @@
use crate::{
metrics::ParallelTrieMetrics,
proof_task::{
AccountMultiproofInput, ProofTaskKind, ProofTaskManagerHandle, StorageProofInput,
},
proof_task::{AccountMultiproofInput, ProofTaskManagerHandle, StorageProofInput},
root::ParallelStateRootError,
StorageRootTargets,
};
@@ -16,10 +14,7 @@ use reth_trie::{
DecodedMultiProof, DecodedStorageMultiProof, HashedPostStateSorted, MultiProofTargets, Nibbles,
};
use reth_trie_common::added_removed_keys::MultiAddedRemovedKeys;
use std::sync::{
mpsc::{channel, Receiver},
Arc,
};
use std::sync::{mpsc::Receiver, Arc};
use tracing::trace;

/// Parallel proof calculator.
@@ -41,7 +36,7 @@ pub struct ParallelProof {
collect_branch_node_masks: bool,
/// Provided by the user to give the necessary context to retain extra proofs.
multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
/// Handle to the proof task manager.
/// Handle to the proof worker pools.
proof_task_handle: ProofTaskManagerHandle,
/// Cached storage proof roots for missed leaves; this maps
/// hashed (missed) addresses to their storage proof roots.
@@ -93,7 +88,10 @@ impl ParallelProof {
hashed_address: B256,
prefix_set: PrefixSet,
target_slots: B256Set,
) -> Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>> {
) -> Result<
Receiver<Result<DecodedStorageMultiProof, ParallelStateRootError>>,
ParallelStateRootError,
> {
let input = StorageProofInput::new(
hashed_address,
prefix_set,
@@ -102,9 +100,9 @@
self.multi_added_removed_keys.clone(),
);

let (sender, receiver) = std::sync::mpsc::channel();
let _ = self.proof_task_handle.queue_task(ProofTaskKind::StorageProof(input, sender));
receiver
self.proof_task_handle
.queue_storage_proof(input)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))
}

/// Generate a storage multiproof according to the specified targets and hashed address.
@@ -124,7 +122,7 @@
"Starting storage proof generation"
);

let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots);
let receiver = self.queue_storage_proof(hashed_address, prefix_set, target_slots)?;
let proof_result = receiver.recv().map_err(|_| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(DatabaseError::Other(
format!("channel closed for {hashed_address}"),
@@ -193,15 +191,10 @@ impl ParallelProof {
missed_leaves_storage_roots: self.missed_leaves_storage_roots.clone(),
};

let (sender, receiver) = channel();
self.proof_task_handle
.queue_task(ProofTaskKind::AccountMultiproof(input, sender))
.map_err(|_| {
ParallelStateRootError::Other(
"Failed to queue account multiproof: account worker pool unavailable"
.to_string(),
)
})?;
let receiver = self
.proof_task_handle
.queue_account_multiproof(input)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))?;

// Wait for account multiproof result from worker
let (multiproof, stats) = receiver.recv().map_err(|_| {
@@ -231,7 +224,7 @@ impl ParallelProof {
#[cfg(test)]
mod tests {
use super::*;
use crate::proof_task::{ProofTaskCtx, ProofTaskManager};
use crate::proof_task::{spawn_proof_workers, ProofTaskCtx};
use alloy_primitives::{
keccak256,
map::{B256Set, DefaultHashBuilder, HashMap},
@@ -313,13 +306,8 @@

let task_ctx =
ProofTaskCtx::new(Default::default(), Default::default(), Default::default());
let proof_task =
ProofTaskManager::new(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();
let proof_task_handle = proof_task.handle();

// keep the join handle around to make sure it does not return any errors
// after we compute the state root
let join_handle = rt.spawn_blocking(move || proof_task.run());
let proof_task_handle =
spawn_proof_workers(rt.handle().clone(), consistent_view, task_ctx, 1, 1).unwrap();

let parallel_result = ParallelProof::new(
Default::default(),
@@ -354,9 +342,7 @@
// then compare the entire thing for any mask differences
assert_eq!(parallel_result, sequential_result_decoded);

// drop the handle to terminate the task and then block on the proof task handle to make
// sure it does not return any errors
// Workers shut down automatically when handle is dropped
drop(proof_task_handle);
rt.block_on(join_handle).unwrap().expect("The proof task should not return an error");
}
}
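
For context, the removed run() loop and join handle above reflect the new lifecycle this PR introduces: workers are spawned up front by spawn_proof_workers, callers queue work through a cloneable handle, and the pool shuts down once the last handle is dropped. Below is a minimal, self-contained analogue of that pattern built only from std types; the names (WorkerHandle, spawn_workers, queue) are illustrative and not the reth API.

use std::sync::{
    mpsc::{channel, Receiver, Sender},
    Arc, Mutex,
};
use std::thread;

/// Cloneable handle to the worker pool; dropping the last clone disconnects
/// the job channel, which is the shutdown signal for the workers.
#[derive(Clone)]
struct WorkerHandle {
    job_tx: Sender<(u64, Sender<u64>)>,
}

impl WorkerHandle {
    /// Queue a job and return a receiver for its result, mirroring the shape
    /// of queue_storage_proof / queue_account_multiproof in the diff.
    fn queue(&self, job: u64) -> Receiver<u64> {
        let (result_tx, result_rx) = channel();
        self.job_tx.send((job, result_tx)).expect("worker pool is alive");
        result_rx
    }
}

/// Spawn `count` workers up front and hand back only the handle, analogous to
/// spawn_proof_workers replacing the old manager plus blocking run() call.
fn spawn_workers(count: usize) -> WorkerHandle {
    let (job_tx, job_rx) = channel::<(u64, Sender<u64>)>();
    let job_rx = Arc::new(Mutex::new(job_rx));
    for _ in 0..count {
        let job_rx = Arc::clone(&job_rx);
        thread::spawn(move || loop {
            // Take one job at a time; recv() errors once every handle clone
            // has been dropped, at which point the worker exits.
            let msg = job_rx.lock().unwrap().recv();
            match msg {
                Ok((job, result_tx)) => {
                    let _ = result_tx.send(job * 2);
                }
                Err(_) => break,
            }
        });
    }
    WorkerHandle { job_tx }
}

fn main() {
    let handle = spawn_workers(2);
    let result_rx = handle.queue(21);
    assert_eq!(result_rx.recv().unwrap(), 42);
    // No explicit join or run() step: dropping the handle shuts the pool down.
    drop(handle);
}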