From 143ce6685e41154c960e04a1eea556439288f7ca Mon Sep 17 00:00:00 2001 From: lightsing Date: Tue, 16 Sep 2025 16:30:59 +0800 Subject: [PATCH] poc --- Cargo.lock | 16 + Cargo.toml | 2 +- crates/stateless/src/validation.rs | 13 + crates/trie/linked/Cargo.toml | 27 + crates/trie/linked/src/execution_witness.rs | 130 +++ crates/trie/linked/src/lib.rs | 170 ++++ crates/trie/linked/src/mpt.rs | 1004 +++++++++++++++++++ 7 files changed, 1361 insertions(+), 1 deletion(-) create mode 100644 crates/trie/linked/Cargo.toml create mode 100644 crates/trie/linked/src/execution_witness.rs create mode 100644 crates/trie/linked/src/lib.rs create mode 100644 crates/trie/linked/src/mpt.rs diff --git a/Cargo.lock b/Cargo.lock index 58eb17d4cc9..de2c4c50e75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10678,6 +10678,22 @@ dependencies = [ "triehash", ] +[[package]] +name = "reth-trie-linked" +version = "1.7.0" +dependencies = [ + "alloy-primitives", + "alloy-rlp", + "alloy-trie", + "itertools 0.14.0", + "reth-stateless", + "reth-storage-errors", + "reth-trie-common", + "revm-bytecode", + "serde", + "thiserror 2.0.16", +] + [[package]] name = "reth-trie-parallel" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index a4e39d05e98..a8fac0eb1db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -173,7 +173,7 @@ members = [ "testing/ef-tests/", "testing/testing-utils", "testing/runner", - "crates/tracing-otlp", + "crates/tracing-otlp", "crates/trie/linked", ] default-members = ["bin/reth"] exclude = ["docs/cli"] diff --git a/crates/stateless/src/validation.rs b/crates/stateless/src/validation.rs index 23308bcfa55..5fa2acf9698 100644 --- a/crates/stateless/src/validation.rs +++ b/crates/stateless/src/validation.rs @@ -86,6 +86,19 @@ pub enum StatelessValidationError { expected: B256, }, + /// Error when the computed pre-storage root does not match the expected one. + #[error( + "mismatched pre-storage root of account with hash {address_hash}: {got} \n {expected}" + )] + PreStorageRootMismatch { + /// The account address hash + address_hash: B256, + /// The computed pre-state root + got: B256, + /// The expected pre-state root from the previous block + expected: B256, + }, + /// Custom error. #[error("{0}")] Custom(&'static str), diff --git a/crates/trie/linked/Cargo.toml b/crates/trie/linked/Cargo.toml new file mode 100644 index 00000000000..0ff92a4271d --- /dev/null +++ b/crates/trie/linked/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "reth-trie-linked" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true +exclude.workspace = true + +[lints] +workspace = true + +[dependencies] +alloy-primitives.workspace = true +alloy-rlp.workspace = true +alloy-trie.workspace = true +reth-stateless.workspace = true +reth-storage-errors.workspace = true +reth-trie-common.workspace = true +revm-bytecode.workspace = true +serde = { workspace = true, optional = true } +thiserror.workspace = true +itertools.workspace = true + +[features] +serde = ["dep:serde", "serde/derive", "alloy-primitives/serde"] diff --git a/crates/trie/linked/src/execution_witness.rs b/crates/trie/linked/src/execution_witness.rs new file mode 100644 index 00000000000..f1ac8ed7385 --- /dev/null +++ b/crates/trie/linked/src/execution_witness.rs @@ -0,0 +1,130 @@ +//! This is copied and modified from https://github.com/succinctlabs/rsp +//! crates/mpt/src/execution_witness.rs rev@2a99f35a9b81452eb53af3848e50addfd481363c +//! Under MIT license + +use crate::mpt::{resolve_nodes, MptNode, MptNodeData, MptNodeReference}; +use alloy_primitives::{ + keccak256, + map::{B256Map, HashMap}, + Bytes, B256, +}; +use alloy_rlp::Decodable; +use alloy_trie::TrieAccount; +use reth_stateless::validation::StatelessValidationError; + +// Builds tries from the witness state. +// +// NOTE: This method should be called outside zkVM! In general, you construct tries, then +// validate them inside zkVM. +pub(crate) fn build_validated_tries<'a, I>( + pre_state_root: B256, + states: I, +) -> Result<(MptNode, B256Map), StatelessValidationError> +where + I: IntoIterator, +{ + let err_fn = |_e| StatelessValidationError::WitnessRevealFailed { pre_state_root }; + + // Step 1: Decode all RLP-encoded trie nodes and index by hash + // IMPORTANT: Witness state contains both *state trie* nodes and *storage tries* nodes! + let mut node_map = HashMap::::default(); + let mut node_by_hash = B256Map::::default(); + let mut root_node: Option = None; + + for encoded in states.into_iter() { + let node = MptNode::decode(&mut encoded.as_ref()).map_err(err_fn)?; + let hash = keccak256(encoded); + if hash == pre_state_root { + root_node = Some(node.clone()); + } + node_by_hash.insert(hash, node.clone()); + node_map.insert(node.reference(), node); + } + + // Step 2: Use root_node or fallback to Digest + let root = root_node.unwrap_or_else(|| MptNodeData::Digest(pre_state_root).into()); + + // Build state trie. + let mut raw_storage_tries = Vec::with_capacity(node_by_hash.len()); + let state_trie = resolve_nodes(&root, &node_map); + + state_trie.try_for_each_leaves(|key, mut value| { + let account = TrieAccount::decode(&mut value).map_err(err_fn)?; + let hashed_address = B256::from_slice(key); + raw_storage_tries.push((hashed_address, account.storage_root)); + Ok::<_, StatelessValidationError>(()) + })?; + + // Step 3: Build storage tries per account efficiently + let mut storage_tries = + B256Map::::with_capacity_and_hasher(raw_storage_tries.len(), Default::default()); + + for (hashed_address, storage_root) in raw_storage_tries { + let root_node = match node_by_hash.get(&storage_root).cloned() { + Some(node) => node, + None => { + // An execution witness can include an account leaf (with non-empty storageRoot), + // but omit its entire storage trie when that account's storage was + // NOT touched during the block. + continue; + } + }; + let storage_trie = resolve_nodes(&root_node, &node_map); + + if storage_trie.is_digest() { + return Err(StatelessValidationError::WitnessRevealFailed { pre_state_root }); + } + + // Insert resolved storage trie. + storage_tries.insert(hashed_address, storage_trie); + } + + // Step 3a: Verify that state_trie was built correctly - confirm tree hash with pre-state root. + validate_state_trie(&state_trie, pre_state_root)?; + + // Step 3b: Verify that each storage trie matches the declared storage_root in the state trie. + validate_storage_tries(pre_state_root, &state_trie, &storage_tries)?; + + Ok((state_trie, storage_tries)) +} + +// Validate that state_trie was built correctly - confirm tree hash with prev state root. +fn validate_state_trie( + state_trie: &MptNode, + pre_state_root: B256, +) -> Result<(), StatelessValidationError> { + if state_trie.hash() != pre_state_root { + return Err(StatelessValidationError::PreStateRootMismatch { + got: state_trie.hash(), + expected: pre_state_root, + }); + } + Ok(()) +} + +// Validates that each storage trie matches the declared storage_root in the state trie. +fn validate_storage_tries( + pre_state_root: B256, + state_trie: &MptNode, + storage_tries: &B256Map, +) -> Result<(), StatelessValidationError> { + for (address_hash, storage_trie) in storage_tries.iter() { + let account = state_trie + .get_rlp::(address_hash.as_slice()) + .map_err(|_e| StatelessValidationError::WitnessRevealFailed { pre_state_root })? + .ok_or(StatelessValidationError::WitnessRevealFailed { pre_state_root })?; + + let expected = account.storage_root; + let got = storage_trie.hash(); + + if expected != got { + return Err(StatelessValidationError::PreStorageRootMismatch { + address_hash: *address_hash, + expected, + got, + }); + } + } + + Ok(()) +} diff --git a/crates/trie/linked/src/lib.rs b/crates/trie/linked/src/lib.rs new file mode 100644 index 00000000000..56fec2fd0f0 --- /dev/null +++ b/crates/trie/linked/src/lib.rs @@ -0,0 +1,170 @@ +//! Linked Partial Merkle Patricia Trie optimized for stateless execution in zkvm + +use crate::mpt::MptNode; +use alloy_primitives::{keccak256, map::B256Map, Address, Bytes, B256, U256}; +use alloy_trie::{TrieAccount, EMPTY_ROOT_HASH}; +use itertools::Itertools; +use reth_stateless::{validation::StatelessValidationError, ExecutionWitness, StatelessTrie}; +use reth_storage_errors::ProviderError; +use reth_trie_common::HashedPostState; +use revm_bytecode::Bytecode; + +mod execution_witness; +mod mpt; + +/// A partial trie that can be updated +#[derive(Clone, Debug, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct LinkedSparseTrie { + /// The state trie + pub state_trie: MptNode, + /// The storage tries, keyed by the keccak256 hash of the account address + pub storage_tries: B256Map, +} + +impl LinkedSparseTrie { + /// Create a partial state trie from a previous state root and a list of RLP-encoded MPT nodes + pub fn new<'a, I>( + prev_state_root: B256, + states: I, + ) -> Result + where + I: IntoIterator, + { + let (state_trie, storage_tries) = + execution_witness::build_validated_tries(prev_state_root, states)?; + + Ok(LinkedSparseTrie { state_trie, storage_tries }) + } + + /// Get account by address + pub fn account(&self, address: Address) -> Result, ProviderError> { + let hashed_address = keccak256(address); + let account = self.state_trie.get_rlp::(&*hashed_address)?; + + Ok(account) + } + + /// Get storage value of an account at a specific slot. + pub fn storage(&self, address: Address, slot: U256) -> Result { + let hashed_address = keccak256(address); + + // Usual case, where given storage slot is present. + if let Some(storage_trie) = self.storage_tries.get(&hashed_address) { + let key = keccak256(slot.to_be_bytes::<32>()); + let value = storage_trie.get_rlp::(&*key)?.unwrap_or_default(); + return Ok(value); + } + + // Storage slot value is not present in the trie, validate that the witness is complete. + // FIXME: do we need to validate that the storage trie is empty? seems get_rlp has already + // proven that. + let account = self.state_trie.get_rlp::(&*hashed_address)?; + match account { + Some(account) => { + if account.storage_root != EMPTY_ROOT_HASH { + // if an account has a non-empty storage root, then during the building of the + // trie, we already failed if the storage trie was not present. + unreachable!("pre-built storage trie shall be present"); + } + } + None => { + todo!("Validate that account witness is valid"); + } + } + + // Account doesn't exist or has empty storage root. + Ok(U256::ZERO) + } + + /// Mutates state based on diffs provided in [`HashedPostState`]. + pub fn calculate_state_root( + &mut self, + state: HashedPostState, + ) -> Result { + let err_fn = |_e| StatelessValidationError::StatelessStateRootCalculationFailed; + // 1. Apply storage‑slot updates and compute each contract’s storage root + // + // + // We walk over every (address, storage) pair in deterministic order + // and update the corresponding per‑account storage trie in‑place. + for (address, storage) in + state.storages.into_iter().sorted_unstable_by_key(|(addr, _)| *addr) + { + let storage_trie = self.storage_tries.entry(address).or_default(); + if storage.wiped { + storage_trie.clear(); + } + + // Apply slot‑level changes + for (hashed_slot, value) in + storage.storage.into_iter().sorted_unstable_by_key(|(slot, _)| *slot) + { + if value.is_zero() { + storage_trie.delete(&*hashed_slot).map_err(err_fn)?; + } else { + storage_trie.insert_rlp(&*hashed_slot, value).map_err(err_fn)?; + } + } + } + + for (hashed_address, account) in + state.accounts.into_iter().sorted_unstable_by_key(|(addr, _)| *addr) + { + let Some(account) = account else { + self.state_trie.delete(&*hashed_address).map_err(err_fn)?; + continue; + }; + + // Determine which storage root should be used for this account + let storage_root = if let Some(storage_trie) = self.storage_tries.get(&hashed_address) { + storage_trie.hash() + } else if let Some(value) = + self.state_trie.get_rlp::(&*hashed_address).map_err(err_fn)? + { + value.storage_root + } else { + EMPTY_ROOT_HASH + }; + + let account = account.into_trie_account(storage_root); + self.state_trie.insert_rlp(&*hashed_address, account).map_err(err_fn)?; + } + + Ok(self.state_trie.hash()) + } +} + +impl StatelessTrie for LinkedSparseTrie { + fn new( + witness: &ExecutionWitness, + pre_state_root: B256, + ) -> Result<(Self, B256Map), StatelessValidationError> + where + Self: Sized, + { + let mut bytecode = B256Map::default(); + for code in witness.codes.iter() { + let hash = keccak256(code); + bytecode.insert(hash, Bytecode::new_raw(code.clone())); + } + + let trie = Self::new(pre_state_root, &witness.state)?; + Ok((trie, bytecode)) + } + + fn account(&self, address: Address) -> Result, ProviderError> { + LinkedSparseTrie::account(self, address) + } + + fn storage(&self, address: Address, slot: U256) -> Result { + LinkedSparseTrie::storage(self, address, slot) + } + + fn calculate_state_root( + &mut self, + state: HashedPostState, + ) -> Result { + LinkedSparseTrie::calculate_state_root(self, state) + } +} diff --git a/crates/trie/linked/src/mpt.rs b/crates/trie/linked/src/mpt.rs new file mode 100644 index 00000000000..fea63a0489b --- /dev/null +++ b/crates/trie/linked/src/mpt.rs @@ -0,0 +1,1004 @@ +//! This is copied and modified from https://github.com/succinctlabs/rsp +//! crates/mpt/src/mpt.rs rev@2a99f35a9b81452eb53af3848e50addfd481363c +//! Under MIT license +// This code is modified from the original implementation of Zeth. +// +// Reference: https://github.com/risc0/zeth +// +// Copyright 2023 RISC Zero, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use alloy_primitives::{keccak256, map::HashMap, B256}; +use alloy_rlp::{Decodable, Encodable, Header, EMPTY_STRING_CODE}; +use alloy_trie::{Nibbles, EMPTY_ROOT_HASH}; +use reth_storage_errors::ProviderError; +use std::{cmp, convert::Infallible, fmt::Debug, iter, mem, sync::Mutex}; + +trait RlpBytes { + /// Returns the RLP-encoding. + fn to_rlp(&self) -> Vec; +} + +impl RlpBytes for T +where + T: Encodable, +{ + #[inline] + fn to_rlp(&self) -> Vec { + let rlp_length = self.length(); + let mut out = Vec::with_capacity(rlp_length); + self.encode(&mut out); + debug_assert_eq!(out.len(), rlp_length); + out + } +} + +/// Represents the root node of a sparse Merkle Patricia Trie. +/// +/// The "sparse" nature of this trie allows for truncation of certain unneeded parts, +/// representing them by their node hash. This design choice is particularly useful for +/// optimizing storage. However, operations targeting a truncated part will fail and +/// return an error. Another distinction of this implementation is that branches cannot +/// store values, aligning with the construction of MPTs in Ethereum. +#[derive(Debug, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct MptNode { + /// The type and data of the node. + data: MptNodeData, + /// Cache for a previously computed reference of this node. This is skipped during + /// serialization. + #[cfg_attr(feature = "serde", serde(skip))] + cached_reference: Mutex>, +} + +impl Ord for MptNode { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.data.cmp(&other.data) + } +} + +impl PartialOrd for MptNode { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Eq for MptNode {} + +impl PartialEq for MptNode { + fn eq(&self, other: &Self) -> bool { + self.data == other.data + } +} + +impl Clone for MptNode { + fn clone(&self) -> Self { + Self { + data: self.data.clone(), + cached_reference: Mutex::new(self.cached_reference.lock().unwrap().clone()), + } + } +} + +/// Represents the various types of data that can be stored within a node in the sparse +/// Merkle Patricia Trie (MPT). +/// +/// Each node in the trie can be of one of several types, each with its own specific data +/// structure. This enum provides a clear and type-safe way to represent the data +/// associated with each node type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum MptNodeData { + /// Represents an empty trie node. + #[default] + Null, + /// A node that can have up to 16 children. Each child is an optional boxed [MptNode]. + Branch([Option>; 16]), + /// A leaf node that contains a key and a value, both represented as byte vectors. + Leaf(Vec, Vec), + /// A node that has exactly one child and is used to represent a shared prefix of + /// several keys. + Extension(Vec, Box), + /// Represents a sub-trie by its hash, allowing for efficient storage of large + /// sub-tries without storing their entire content. + Digest(B256), +} + +/// Represents the ways in which one node can reference another node inside the sparse +/// Merkle Patricia Trie (MPT). +/// +/// Nodes in the MPT can reference other nodes either directly through their byte +/// representation or indirectly through a hash of their encoding. This enum provides a +/// clear and type-safe way to represent these references. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum MptNodeReference { + /// Represents a direct reference to another node using its byte encoding. Typically + /// used for short encodings that are less than 32 bytes in length. + Bytes(Vec), + /// Represents an indirect reference to another node using the Keccak hash of its long + /// encoding. Used for encodings that are not less than 32 bytes in length. + Digest(B256), +} + +/// Provides a conversion from [MptNodeData] to [MptNode]. +/// +/// This implementation allows for conversion from [MptNodeData] to [MptNode], +/// initializing the `data` field with the provided value and setting the +/// `cached_reference` field to `None`. +impl From for MptNode { + fn from(value: MptNodeData) -> Self { + Self { data: value, cached_reference: Mutex::new(None) } + } +} + +/// Provides encoding functionalities for the `MptNode` type. +/// +/// This implementation allows for the serialization of an [MptNode] into its RLP-encoded +/// form. The encoding is done based on the type of node data ([MptNodeData]) it holds. +impl Encodable for MptNode { + /// Encodes the node into the provided `out` buffer. + /// + /// The encoding is done using the Recursive Length Prefix (RLP) encoding scheme. The + /// method handles different node data types and encodes them accordingly. + #[inline] + fn encode(&self, out: &mut dyn alloy_rlp::BufMut) { + match &self.data { + MptNodeData::Null => { + out.put_u8(EMPTY_STRING_CODE); + } + MptNodeData::Branch(nodes) => { + Header { list: true, payload_length: self.payload_length() }.encode(out); + nodes.iter().for_each(|child| match child { + Some(node) => node.reference_encode(out), + None => out.put_u8(EMPTY_STRING_CODE), + }); + // in the MPT reference, branches have values so always add empty value + out.put_u8(EMPTY_STRING_CODE); + } + MptNodeData::Leaf(prefix, value) => { + Header { list: true, payload_length: self.payload_length() }.encode(out); + prefix.as_slice().encode(out); + value.as_slice().encode(out); + } + MptNodeData::Extension(prefix, node) => { + Header { list: true, payload_length: self.payload_length() }.encode(out); + prefix.as_slice().encode(out); + node.reference_encode(out); + } + MptNodeData::Digest(digest) => { + digest.encode(out); + } + } + } + + /// Returns the length of the encoded node in bytes. + /// + /// This method calculates the length of the RLP-encoded node. It's useful for + /// determining the size requirements for storage or transmission. + #[inline] + fn length(&self) -> usize { + let payload_length = self.payload_length(); + payload_length + alloy_rlp::length_of_length(payload_length) + } +} + +/// Provides decoding functionalities for the [MptNode] type. +/// +/// This implementation allows for the deserialization of an RLP-encoded [MptNode] back +/// into its original form. The decoding is done based on the prototype of the RLP data, +/// ensuring that the node is reconstructed accurately. +impl Decodable for MptNode { + #[inline] + fn decode(buf: &mut &[u8]) -> alloy_rlp::Result { + let mut items = match Header::decode_raw(buf)? { + alloy_rlp::PayloadView::List(list) => list, + alloy_rlp::PayloadView::String(val) => { + return if val.is_empty() { + Ok(MptNodeData::Null.into()) + } else if val.len() == 32 { + Ok(MptNodeData::Digest(B256::from_slice(val)).into()) + } else { + println!("Invalid digest length: {val:?} {}", val.len()); + Err(alloy_rlp::Error::Custom("invalid digest")) + }; + } + }; + + // A valid number of trie node items is either 17 (branch node) + // or 2 (extension or leaf node). + match items.len() { + 17 => { + let mut node_list = Vec::with_capacity(16); + for item in items.iter().take(16) { + if *item == [EMPTY_STRING_CODE] { + node_list.push(None); + } else { + node_list.push(Some(Box::new(MptNode::decode(&mut &**item)?))); + } + } + if items[16] != [EMPTY_STRING_CODE] { + return Err(alloy_rlp::Error::Custom("branch node values are not supported")); + } + Ok(MptNodeData::Branch(node_list.try_into().unwrap()).into()) + } + 2 => { + let path = Header::decode_bytes(&mut &*items[0], false)?; + let prefix = path[0]; + if (prefix & (2 << 4)) == 0 { + let node = MptNode::decode(&mut items[1])?; + Ok(MptNodeData::Extension(path.to_vec(), Box::new(node)).into()) + } else { + let value = Header::decode_bytes(&mut &*items[1], false)?; + Ok(MptNodeData::Leaf(path.to_vec(), value.to_vec()).into()) + } + } + _ => Err(alloy_rlp::Error::Custom("invalid number of items in the list")), + } + } +} + +/// Represents a node in the sparse Merkle Patricia Trie (MPT). +/// +/// The [MptNode] type encapsulates the data and functionalities associated with a node in +/// the MPT. It provides methods for manipulating the trie, such as inserting, deleting, +/// and retrieving values, as well as utility methods for encoding, decoding, and +/// debugging. +impl MptNode { + /// Clears the trie, replacing its data with an empty node, [MptNodeData::Null]. + /// + /// This method effectively removes all key-value pairs from the trie. + #[inline] + pub fn clear(&mut self) { + self.data = MptNodeData::Null; + self.invalidate_ref_cache(); + } + + /// Retrieves the underlying data of the node. + /// + /// This method provides a reference to the node's data, allowing for inspection and + /// manipulation. + #[inline] + pub fn as_data(&self) -> &MptNodeData { + &self.data + } + + /// Retrieves the [MptNodeReference] reference of the node when it's referenced inside + /// another node. + /// + /// This method provides a way to obtain a compact representation of the node for + /// storage or transmission purposes. + #[inline] + pub fn reference(&self) -> MptNodeReference { + self.cached_reference.lock().unwrap().get_or_insert_with(|| self.calc_reference()).clone() + } + + /// Transforms a sequence of nibbles into an encoded path. + pub fn for_each_leaves(&self, mut f: F) { + let _ = self.try_for_each_leaves(|k, v| { + f(k, v); + Ok::<(), Infallible>(()) + }); + } + + /// Transforms a sequence of nibbles into an encoded path. + pub fn try_for_each_leaves Result<(), E>>( + &self, + mut f: F, + ) -> Result<(), E> { + let mut stack = vec![(self, Nibbles::default())]; + + while let Some((node, path)) = stack.pop() { + match node.as_data() { + MptNodeData::Null | MptNodeData::Digest(_) => (), + MptNodeData::Branch(branch) => { + for (i, n) in + branch.iter().enumerate().filter_map(|(i, n)| n.as_ref().map(|n| (i, n))) + { + let mut new_path = path; + new_path.push(i as u8); + stack.push((n, new_path)); + } + } + MptNodeData::Leaf(prefix, value) => { + let mut full_path = path; + full_path.extend(&Nibbles::from_nibbles(prefix_nibs(prefix))); + f(&full_path.pack(), value)?; + } + MptNodeData::Extension(prefix, node) => { + let mut new_path = path; + new_path.extend(&Nibbles::from_nibbles(prefix_nibs(prefix))); + stack.push((node, new_path)); + } + } + } + + Ok(()) + } + + /// Computes and returns the 256-bit hash of the node. + /// + /// This method provides a unique identifier for the node based on its content. + #[inline] + pub fn hash(&self) -> B256 { + match self.data { + MptNodeData::Null => EMPTY_ROOT_HASH, + _ => match self.reference() { + MptNodeReference::Digest(digest) => digest, + MptNodeReference::Bytes(bytes) => keccak256(bytes), + }, + } + } + + /// Encodes the [MptNodeReference] of this node into the `out` buffer. + fn reference_encode(&self, out: &mut dyn alloy_rlp::BufMut) { + match self.reference() { + // if the reference is an RLP-encoded byte slice, copy it directly + MptNodeReference::Bytes(bytes) => out.put_slice(&bytes), + // if the reference is a digest, RLP-encode it with its fixed known length + MptNodeReference::Digest(digest) => { + out.put_u8(alloy_rlp::EMPTY_STRING_CODE + 32); + out.put_slice(digest.as_slice()); + } + } + } + + /// Returns the length of the encoded [MptNodeReference] of this node. + fn reference_length(&self) -> usize { + match self.reference() { + MptNodeReference::Bytes(bytes) => bytes.len(), + MptNodeReference::Digest(_) => 1 + 32, + } + } + + fn calc_reference(&self) -> MptNodeReference { + match &self.data { + MptNodeData::Null => MptNodeReference::Bytes(vec![EMPTY_STRING_CODE]), + MptNodeData::Digest(digest) => MptNodeReference::Digest(*digest), + _ => { + let encoded = alloy_rlp::encode(self); + if encoded.len() < 32 { + MptNodeReference::Bytes(encoded) + } else { + MptNodeReference::Digest(keccak256(encoded)) + } + } + } + } + + /// Determines if the trie is empty. + /// + /// This method checks if the node represents an empty trie, i.e., it doesn't contain + /// any key-value pairs. + #[inline] + pub fn is_empty(&self) -> bool { + matches!(&self.data, MptNodeData::Null) + } + + /// Determines if the node represents a digest. + /// + /// A digest is a compact representation of a sub-trie, represented by its hash. + #[inline] + pub fn is_digest(&self) -> bool { + matches!(&self.data, MptNodeData::Digest(_)) + } + + /// Retrieves the value associated with a given key in the trie. + /// + /// If the key is not present in the trie, this method returns `None`. Otherwise, it + /// returns a reference to the associated value. If [None] is returned, the key is + /// provably not in the trie. + #[inline] + pub fn get(&self, key: &[u8]) -> Result, ProviderError> { + self.get_internal(&to_nibs(key)) + } + + /// Retrieves the RLP-decoded value corresponding to the key. + /// + /// If the key is not present in the trie, this method returns `None`. Otherwise, it + /// returns the RLP-decoded value. + #[inline] + pub fn get_rlp(&self, key: &[u8]) -> Result, ProviderError> { + match self.get(key)? { + Some(mut bytes) => Ok(Some(T::decode(&mut bytes)?)), + None => Ok(None), + } + } + + fn get_internal(&self, key_nibs: &[u8]) -> Result, ProviderError> { + match &self.data { + MptNodeData::Null => Ok(None), + MptNodeData::Branch(nodes) => { + if let Some((i, tail)) = key_nibs.split_first() { + match nodes[*i as usize] { + Some(ref node) => node.get_internal(tail), + None => Ok(None), + } + } else { + Ok(None) + } + } + MptNodeData::Leaf(prefix, value) => { + if prefix_nibs(prefix) == key_nibs { + Ok(Some(value)) + } else { + Ok(None) + } + } + MptNodeData::Extension(prefix, node) => { + if let Some(tail) = key_nibs.strip_prefix(prefix_nibs(prefix).as_slice()) { + node.get_internal(tail) + } else { + Ok(None) + } + } + MptNodeData::Digest(digest) => { + Err(ProviderError::TrieWitnessError(format!("incomplete witness for {digest:?}"))) + } + } + } + + /// Removes a key from the trie. + /// + /// This method attempts to remove a key-value pair from the trie. If the key is + /// present, it returns `true`. Otherwise, it returns `false`. + #[inline] + pub fn delete(&mut self, key: &[u8]) -> Result { + self.delete_internal(&to_nibs(key)) + } + + fn delete_internal(&mut self, key_nibs: &[u8]) -> Result { + match &mut self.data { + MptNodeData::Null => return Ok(false), + MptNodeData::Branch(children) => { + if let Some((i, tail)) = key_nibs.split_first() { + let child = &mut children[*i as usize]; + match child { + Some(node) => { + if !node.delete_internal(tail)? { + return Ok(false); + } + // if the node is now empty, remove it + if node.is_empty() { + *child = None; + } + } + None => return Ok(false), + } + } else { + return Err(ProviderError::TrieWitnessError("value in branch".to_string())); + } + + let mut remaining = children.iter_mut().enumerate().filter(|(_, n)| n.is_some()); + // there will always be at least one remaining node + let (index, node) = remaining.next().unwrap(); + // if there is only exactly one node left, we need to convert the branch + if remaining.next().is_none() { + let mut orphan = node.take().unwrap(); + match &mut orphan.data { + // if the orphan is a leaf, prepend the corresponding nib to it + MptNodeData::Leaf(prefix, orphan_value) => { + let new_nibs: Vec<_> = + iter::once(index as u8).chain(prefix_nibs(prefix)).collect(); + self.data = MptNodeData::Leaf( + to_encoded_path(&new_nibs, true), + mem::take(orphan_value), + ); + } + // if the orphan is an extension, prepend the corresponding nib to it + MptNodeData::Extension(prefix, orphan_child) => { + let new_nibs: Vec<_> = + iter::once(index as u8).chain(prefix_nibs(prefix)).collect(); + self.data = MptNodeData::Extension( + to_encoded_path(&new_nibs, false), + mem::take(orphan_child), + ); + } + // if the orphan is a branch or digest, convert to an extension + MptNodeData::Branch(_) | MptNodeData::Digest(_) => { + self.data = MptNodeData::Extension( + to_encoded_path(&[index as u8], false), + orphan, + ); + } + MptNodeData::Null => unreachable!(), + } + } + } + MptNodeData::Leaf(prefix, _) => { + if prefix_nibs(prefix) != key_nibs { + return Ok(false); + } + self.data = MptNodeData::Null; + } + MptNodeData::Extension(prefix, child) => { + let mut self_nibs = prefix_nibs(prefix); + if let Some(tail) = key_nibs.strip_prefix(self_nibs.as_slice()) { + if !child.delete_internal(tail)? { + return Ok(false); + } + } else { + return Ok(false); + } + + // an extension can only point to a branch or a digest; since it's sub trie was + // modified, we need to make sure that this property still holds + match &mut child.data { + // if the child is empty, remove the extension + MptNodeData::Null => { + self.data = MptNodeData::Null; + } + // for a leaf, replace the extension with the extended leaf + MptNodeData::Leaf(prefix, value) => { + self_nibs.extend(prefix_nibs(prefix)); + self.data = + MptNodeData::Leaf(to_encoded_path(&self_nibs, true), mem::take(value)); + } + // for an extension, replace the extension with the extended extension + MptNodeData::Extension(prefix, node) => { + self_nibs.extend(prefix_nibs(prefix)); + self.data = MptNodeData::Extension( + to_encoded_path(&self_nibs, false), + mem::take(node), + ); + } + // for a branch or digest, the extension is still correct + MptNodeData::Branch(_) | MptNodeData::Digest(_) => {} + } + } + MptNodeData::Digest(digest) => { + return Err(ProviderError::TrieWitnessError(format!( + "incomplete witness for {digest:?}" + ))) + } + }; + + self.invalidate_ref_cache(); + Ok(true) + } + + /// Inserts an RLP-encoded value into the trie. + /// + /// This method inserts a value that's been encoded using RLP into the trie. + #[inline] + pub fn insert_rlp(&mut self, key: &[u8], value: impl Encodable) -> Result { + self.insert_internal(&to_nibs(key), value.to_rlp()) + } + + fn insert_internal(&mut self, key_nibs: &[u8], value: Vec) -> Result { + match &mut self.data { + MptNodeData::Null => { + self.data = MptNodeData::Leaf(to_encoded_path(key_nibs, true), value); + } + MptNodeData::Branch(children) => { + if let Some((i, tail)) = key_nibs.split_first() { + let child = &mut children[*i as usize]; + match child { + Some(node) => { + if !node.insert_internal(tail, value)? { + return Ok(false); + } + } + // if the corresponding child is empty, insert a new leaf + None => { + *child = Some(Box::new( + MptNodeData::Leaf(to_encoded_path(tail, true), value).into(), + )); + } + } + } else { + return Err(ProviderError::TrieWitnessError("value in branch".to_string())); + } + } + MptNodeData::Leaf(prefix, old_value) => { + let self_nibs = prefix_nibs(prefix); + let common_len = lcp(&self_nibs, key_nibs); + if common_len == self_nibs.len() && common_len == key_nibs.len() { + // if self_nibs == key_nibs, update the value if it is different + if old_value == &value { + return Ok(false); + } + *old_value = value; + } else if common_len == self_nibs.len() || common_len == key_nibs.len() { + return Err(ProviderError::TrieWitnessError("value in branch".to_string())); + } else { + let split_point = common_len + 1; + // otherwise, create a branch with two children + let mut children: [Option>; 16] = Default::default(); + + children[self_nibs[common_len] as usize] = Some(Box::new( + MptNodeData::Leaf( + to_encoded_path(&self_nibs[split_point..], true), + mem::take(old_value), + ) + .into(), + )); + children[key_nibs[common_len] as usize] = Some(Box::new( + MptNodeData::Leaf(to_encoded_path(&key_nibs[split_point..], true), value) + .into(), + )); + + let branch = MptNodeData::Branch(children); + if common_len > 0 { + // create parent extension for new branch + self.data = MptNodeData::Extension( + to_encoded_path(&self_nibs[..common_len], false), + Box::new(branch.into()), + ); + } else { + self.data = branch; + } + } + } + MptNodeData::Extension(prefix, existing_child) => { + let self_nibs = prefix_nibs(prefix); + let common_len = lcp(&self_nibs, key_nibs); + if common_len == self_nibs.len() { + // traverse down for update + if !existing_child.insert_internal(&key_nibs[common_len..], value)? { + return Ok(false); + } + } else if common_len == key_nibs.len() { + return Err(ProviderError::TrieWitnessError("value in branch".to_string())); + } else { + let split_point = common_len + 1; + // otherwise, create a branch with two children + let mut children: [Option>; 16] = Default::default(); + + children[self_nibs[common_len] as usize] = if split_point < self_nibs.len() { + Some(Box::new( + MptNodeData::Extension( + to_encoded_path(&self_nibs[split_point..], false), + mem::take(existing_child), + ) + .into(), + )) + } else { + Some(mem::take(existing_child)) + }; + children[key_nibs[common_len] as usize] = Some(Box::new( + MptNodeData::Leaf(to_encoded_path(&key_nibs[split_point..], true), value) + .into(), + )); + + let branch = MptNodeData::Branch(children); + if common_len > 0 { + // Create parent extension for new branch + self.data = MptNodeData::Extension( + to_encoded_path(&self_nibs[..common_len], false), + Box::new(branch.into()), + ); + } else { + self.data = branch; + } + } + } + MptNodeData::Digest(digest) => { + return Err(ProviderError::TrieWitnessError(format!( + "incomplete witness for {digest:?}" + ))) + } + }; + + self.invalidate_ref_cache(); + Ok(true) + } + + fn invalidate_ref_cache(&mut self) { + self.cached_reference.lock().unwrap().take(); + } + + /// Returns the length of the RLP payload of the node. + fn payload_length(&self) -> usize { + match &self.data { + MptNodeData::Null => 0, + MptNodeData::Branch(nodes) => { + 1 + nodes + .iter() + .map(|child| child.as_ref().map_or(1, |node| node.reference_length())) + .sum::() + } + MptNodeData::Leaf(prefix, value) => { + prefix.as_slice().length() + value.as_slice().length() + } + MptNodeData::Extension(prefix, node) => { + prefix.as_slice().length() + node.reference_length() + } + MptNodeData::Digest(_) => 32, + } + } +} + +/// Converts a byte slice into a vector of nibbles. +/// +/// A nibble is 4 bits or half of an 8-bit byte. This function takes each byte from the +/// input slice, splits it into two nibbles, and appends them to the resulting vector. +fn to_nibs(slice: &[u8]) -> Vec { + let mut result = Vec::with_capacity(2 * slice.len()); + for byte in slice { + result.push(byte >> 4); + result.push(byte & 0xf); + } + result +} + +/// Encodes a slice of nibbles into a vector of bytes, with an additional prefix to +/// indicate the type of node (leaf or extension). +/// +/// The function starts by determining the type of node based on the `is_leaf` parameter. +/// If the node is a leaf, the prefix is set to `0x20`. If the length of the nibbles is +/// odd, the prefix is adjusted and the first nibble is incorporated into it. +/// +/// The remaining nibbles are then combined into bytes, with each pair of nibbles forming +/// a single byte. The resulting vector starts with the prefix, followed by the encoded +/// bytes. +fn to_encoded_path(mut nibs: &[u8], is_leaf: bool) -> Vec { + let mut prefix = (is_leaf as u8) * 0x20; + if nibs.len() % 2 != 0 { + prefix += 0x10 + nibs[0]; + nibs = &nibs[1..]; + } + iter::once(prefix).chain(nibs.chunks_exact(2).map(|byte| (byte[0] << 4) + byte[1])).collect() +} + +/// Returns the length of the common prefix. +fn lcp(a: &[u8], b: &[u8]) -> usize { + for (i, (a, b)) in iter::zip(a, b).enumerate() { + if a != b { + return i; + } + } + cmp::min(a.len(), b.len()) +} + +fn prefix_nibs(prefix: &[u8]) -> Vec { + let (extension, tail) = prefix.split_first().unwrap(); + // the first bit of the first nibble denotes the parity + let is_odd = extension & (1 << 4) != 0; + + let mut result = Vec::with_capacity(2 * tail.len() + is_odd as usize); + // for odd lengths, the second nibble contains the first element + if is_odd { + result.push(extension & 0xf); + } + for nib in tail { + result.push(nib >> 4); + result.push(nib & 0xf); + } + result +} + +/// Creates a new MPT trie where all the digests contained in `node_store` are resolved. +pub(crate) fn resolve_nodes( + root: &MptNode, + node_store: &HashMap, +) -> MptNode { + let trie = match root.as_data() { + MptNodeData::Null | MptNodeData::Leaf(_, _) => root.clone(), + MptNodeData::Branch(children) => { + let children: Vec<_> = children + .iter() + .map(|child| child.as_ref().map(|node| Box::new(resolve_nodes(node, node_store)))) + .collect(); + MptNodeData::Branch(children.try_into().unwrap()).into() + } + MptNodeData::Extension(prefix, target) => { + MptNodeData::Extension(prefix.clone(), Box::new(resolve_nodes(target, node_store))) + .into() + } + MptNodeData::Digest(digest) => { + if let Some(node) = node_store.get(&MptNodeReference::Digest(*digest)) { + resolve_nodes(node, node_store) + } else { + root.clone() + } + } + }; + // the root hash must not change + debug_assert_eq!(root.hash(), trie.hash()); + + trie +} + +#[cfg(test)] +mod tests { + use super::*; + use alloy_primitives::hex; + + #[test] + fn test_trie_pointer_no_keccak() { + let cases = [("do", "verb"), ("dog", "puppy"), ("doge", "coin"), ("horse", "stallion")]; + for (k, v) in cases { + let node: MptNode = + MptNodeData::Leaf(k.as_bytes().to_vec(), v.as_bytes().to_vec()).into(); + assert!( + matches!(node.reference(),MptNodeReference::Bytes(bytes) if bytes == node.to_rlp().to_vec()) + ); + } + } + + #[test] + fn test_to_encoded_path() { + // extension node with an even path length + let nibbles = vec![0x0a, 0x0b, 0x0c, 0x0d]; + assert_eq!(to_encoded_path(&nibbles, false), vec![0x00, 0xab, 0xcd]); + // extension node with an odd path length + let nibbles = vec![0x0a, 0x0b, 0x0c]; + assert_eq!(to_encoded_path(&nibbles, false), vec![0x1a, 0xbc]); + // leaf node with an even path length + let nibbles = vec![0x0a, 0x0b, 0x0c, 0x0d]; + assert_eq!(to_encoded_path(&nibbles, true), vec![0x20, 0xab, 0xcd]); + // leaf node with an odd path length + let nibbles = vec![0x0a, 0x0b, 0x0c]; + assert_eq!(to_encoded_path(&nibbles, true), vec![0x3a, 0xbc]); + } + + #[test] + fn test_lcp() { + let cases = [ + (vec![], vec![], 0), + (vec![0xa], vec![0xa], 1), + (vec![0xa, 0xb], vec![0xa, 0xc], 1), + (vec![0xa, 0xb], vec![0xa, 0xb], 2), + (vec![0xa, 0xb], vec![0xa, 0xb, 0xc], 2), + (vec![0xa, 0xb, 0xc], vec![0xa, 0xb, 0xc], 3), + (vec![0xa, 0xb, 0xc], vec![0xa, 0xb, 0xc, 0xd], 3), + (vec![0xa, 0xb, 0xc, 0xd], vec![0xa, 0xb, 0xc, 0xd], 4), + ]; + for (a, b, cpl) in cases { + assert_eq!(lcp(&a, &b), cpl) + } + } + + #[test] + fn test_empty() { + let trie = MptNode::default(); + + assert!(trie.is_empty()); + assert_eq!(trie.reference(), MptNodeReference::Bytes(vec![0x80])); + let expected = hex!("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421"); + assert_eq!(expected, trie.hash().0); + + // test RLP encoding + let mut out = Vec::new(); + trie.encode(&mut out); + assert_eq!(out, vec![0x80]); + assert_eq!(trie.length(), out.len()); + let decoded = MptNode::decode(&mut &*out).unwrap(); + assert_eq!(trie.hash(), decoded.hash()); + } + + #[test] + fn test_tiny() { + // trie consisting of an extension, a branch and two leafs + let mut trie = MptNode::default(); + trie.insert_rlp(b"a", 0u8).unwrap(); + trie.insert_rlp(b"b", 1u8).unwrap(); + + assert!(!trie.is_empty()); + let exp_rlp = hex!("d816d680c3208180c220018080808080808080808080808080"); + assert_eq!(trie.reference(), MptNodeReference::Bytes(exp_rlp.to_vec())); + let exp_hash = hex!("6fbf23d6ec055dd143ff50d558559770005ff44ae1d41276f1bd83affab6dd3b"); + assert_eq!(trie.hash().0, exp_hash); + + // test RLP encoding + let mut out = Vec::new(); + trie.encode(&mut out); + assert_eq!(out, exp_rlp.to_vec()); + assert_eq!(trie.length(), out.len()); + let decoded = MptNode::decode(&mut &*out).unwrap(); + assert_eq!(trie.hash(), decoded.hash()); + } + + #[test] + fn test_partial() { + let mut trie = MptNode::default(); + trie.insert_rlp(b"aa", 0u8).unwrap(); + trie.insert_rlp(b"ab", 1u8).unwrap(); + trie.insert_rlp(b"ba", 2u8).unwrap(); + + let exp_hash = trie.hash(); + + // replace one node with its digest + let MptNodeData::Extension(_, node) = &mut trie.data else { panic!("extension expected") }; + **node = MptNodeData::Digest(node.hash()).into(); + assert!(node.is_digest()); + + let trie = MptNode::decode(&mut &*trie.to_rlp()).unwrap(); + assert_eq!(trie.hash(), exp_hash); + + // lookups should fail + trie.get(b"aa").unwrap_err(); + trie.get(b"a0").unwrap_err(); + } + + #[test] + fn test_keccak_trie() { + const N: usize = 512; + + // insert + let mut trie = MptNode::default(); + for i in 0..N { + assert!(trie.insert_rlp(&*keccak256(i.to_be_bytes()), i).unwrap()); + + // check hash against trie build in reverse + let mut reference = MptNode::default(); + for j in (0..=i).rev() { + reference.insert_rlp(&*keccak256(j.to_be_bytes()), j).unwrap(); + } + assert_eq!(trie.hash(), reference.hash()); + } + + let expected = hex!("7310027edebdd1f7c950a7fb3413d551e85dff150d45aca4198c2f6315f9b4a7"); + assert_eq!(trie.hash().0, expected); + + // get + for i in 0..N { + assert_eq!(trie.get_rlp(&*keccak256(i.to_be_bytes())).unwrap(), Some(i)); + assert!(trie.get(&*keccak256((i + N).to_be_bytes())).unwrap().is_none()); + } + + // delete + for i in 0..N { + assert!(trie.delete(&*keccak256(i.to_be_bytes())).unwrap()); + + let mut reference = MptNode::default(); + for j in ((i + 1)..N).rev() { + reference.insert_rlp(&*keccak256(j.to_be_bytes()), j).unwrap(); + } + assert_eq!(trie.hash(), reference.hash()); + } + assert!(trie.is_empty()); + } + + #[test] + fn test_index_trie() { + const N: usize = 512; + + // insert + let mut trie = MptNode::default(); + for i in 0..N { + assert!(trie.insert_rlp(&i.to_rlp(), i).unwrap()); + + // check hash against trie build in reverse + let mut reference = MptNode::default(); + for j in (0..=i).rev() { + reference.insert_rlp(&j.to_rlp(), j).unwrap(); + } + assert_eq!(trie.hash(), reference.hash()); + + // try RLP roundtrip + let decoded = MptNode::decode(&mut &*trie.to_rlp()).unwrap(); + assert_eq!(trie.hash(), decoded.hash()); + } + + // get + for i in 0..N { + assert_eq!(trie.get_rlp(&i.to_rlp()).unwrap(), Some(i)); + assert!(trie.get(&(i + N).to_rlp()).unwrap().is_none()); + } + + // delete + for i in 0..N { + assert!(trie.delete(&i.to_rlp()).unwrap()); + + let mut reference = MptNode::default(); + for j in ((i + 1)..N).rev() { + reference.insert_rlp(&j.to_rlp(), j).unwrap(); + } + assert_eq!(trie.hash(), reference.hash()); + } + assert!(trie.is_empty()); + } +}