diff --git a/Cargo.lock b/Cargo.lock index 0ab63ba8e51c..c6d679125981 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6156,6 +6156,15 @@ dependencies = [ "tracing", ] +[[package]] +name = "reth-etl" +version = "0.1.0-alpha.16" +dependencies = [ + "reth-db", + "reth-primitives", + "tempfile", +] + [[package]] name = "reth-interfaces" version = "0.1.0-alpha.16" @@ -6856,6 +6865,7 @@ dependencies = [ "reth-db", "reth-downloaders", "reth-eth-wire", + "reth-etl", "reth-interfaces", "reth-metrics", "reth-primitives", @@ -6866,6 +6876,7 @@ dependencies = [ "revm", "serde", "serde_json", + "tempfile", "thiserror", "tokio", "tokio-stream", diff --git a/Cargo.toml b/Cargo.toml index 09fa887b68ba..b2448bfef388 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/consensus/beacon-core/", "crates/consensus/common/", "crates/ethereum-forks/", + "crates/etl", "crates/interfaces/", "crates/metrics/", "crates/metrics/metrics-derive/", @@ -129,6 +130,7 @@ reth-ecies = { path = "crates/net/ecies" } reth-eth-wire = { path = "crates/net/eth-wire" } reth-ethereum-forks = { path = "crates/ethereum-forks" } reth-ethereum-payload-builder = { path = "crates/payload/ethereum" } +reth-etl = { path = "crates/etl" } reth-optimism-payload-builder = { path = "crates/payload/optimism" } reth-interfaces = { path = "crates/interfaces" } reth-ipc = { path = "crates/rpc/ipc" } diff --git a/bin/reth/src/commands/debug_cmd/execution.rs b/bin/reth/src/commands/debug_cmd/execution.rs index 6dd9e2d7ef27..c72782ebead0 100644 --- a/bin/reth/src/commands/debug_cmd/execution.rs +++ b/bin/reth/src/commands/debug_cmd/execution.rs @@ -32,7 +32,7 @@ use reth_primitives::{fs, stage::StageId, BlockHashOrNumber, BlockNumber, ChainS use reth_provider::{BlockExecutionWriter, HeaderSyncMode, ProviderFactory, StageCheckpointReader}; use reth_stages::{ sets::DefaultStages, - stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage, TotalDifficultyStage}, + stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage}, Pipeline, StageSet, }; use reth_tasks::TaskExecutor; @@ -123,11 +123,7 @@ impl Command { header_downloader, body_downloader, factory.clone(), - ) - .set( - TotalDifficultyStage::new(consensus) - .with_commit_threshold(stage_conf.total_difficulty.commit_threshold), - ) + )? .set(SenderRecoveryStage { commit_threshold: stage_conf.sender_recovery.commit_threshold, }) diff --git a/bin/reth/src/commands/import.rs b/bin/reth/src/commands/import.rs index 7c57e6c0a498..cc1137961cda 100644 --- a/bin/reth/src/commands/import.rs +++ b/bin/reth/src/commands/import.rs @@ -20,7 +20,7 @@ use reth_primitives::{stage::StageId, ChainSpec, B256}; use reth_provider::{HeaderSyncMode, ProviderFactory, StageCheckpointReader}; use reth_stages::{ prelude::*, - stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage, TotalDifficultyStage}, + stages::{ExecutionStage, ExecutionStageThresholds, SenderRecoveryStage}, }; use std::{path::PathBuf, sync::Arc}; use tokio::sync::watch; @@ -173,11 +173,7 @@ impl ImportCommand { header_downloader, body_downloader, factory.clone(), - ) - .set( - TotalDifficultyStage::new(consensus.clone()) - .with_commit_threshold(config.stages.total_difficulty.commit_threshold), - ) + )? 
.set(SenderRecoveryStage { commit_threshold: config.stages.sender_recovery.commit_threshold, }) diff --git a/crates/consensus/beacon/src/engine/test_utils.rs b/crates/consensus/beacon/src/engine/test_utils.rs index 1d0d49e5336d..2c4f8b14daf9 100644 --- a/crates/consensus/beacon/src/engine/test_utils.rs +++ b/crates/consensus/beacon/src/engine/test_utils.rs @@ -492,14 +492,17 @@ where .build(client.clone(), consensus.clone(), provider_factory.clone()) .into_task(); - Pipeline::builder().add_stages(DefaultStages::new( - ProviderFactory::new(db.clone(), self.base_config.chain_spec.clone()), - HeaderSyncMode::Tip(tip_rx.clone()), - Arc::clone(&consensus), - header_downloader, - body_downloader, - executor_factory.clone(), - )) + Pipeline::builder().add_stages( + DefaultStages::new( + ProviderFactory::new(db.clone(), self.base_config.chain_spec.clone()), + HeaderSyncMode::Tip(tip_rx.clone()), + Arc::clone(&consensus), + header_downloader, + body_downloader, + executor_factory.clone(), + ) + .expect("should build"), + ) } }; diff --git a/crates/etl/Cargo.toml b/crates/etl/Cargo.toml new file mode 100644 index 000000000000..5b7724f03ae8 --- /dev/null +++ b/crates/etl/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "reth-etl" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true +exclude.workspace = true + +[dependencies] +tempfile.workspace = true +reth-db.workspace = true + +[dev-dependencies] +reth-primitives.workspace = true diff --git a/crates/etl/src/lib.rs b/crates/etl/src/lib.rs new file mode 100644 index 000000000000..27814d4ae504 --- /dev/null +++ b/crates/etl/src/lib.rs @@ -0,0 +1,264 @@ +//! ETL data collector. +//! +//! This crate is useful for dumping unsorted data into temporary files and iterating on their +//! sorted representation later on. +//! +//! This has multiple uses, such as optimizing database inserts (for Btree based databases) and +//! memory management (as it moves the buffer to disk instead of memory). + +#![doc( + html_logo_url = "https://raw.githubusercontent.com/paradigmxyz/reth/main/assets/reth-docs.png", + html_favicon_url = "https://avatars0.githubusercontent.com/u/97369466?s=256", + issue_tracker_base_url = "https://github.com/paradigmxyz/reth/issues/" +)] +#![warn(missing_debug_implementations, missing_docs, unreachable_pub, rustdoc::all)] +#![deny(unused_must_use, rust_2018_idioms)] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] + +use std::{ + cmp::Reverse, + collections::BinaryHeap, + io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write}, + path::Path, + sync::Arc, +}; + +use reth_db::table::{Compress, Encode, Key, Value}; +use tempfile::{NamedTempFile, TempDir}; + +/// An ETL (extract, transform, load) data collector. +/// +/// Data is pushed (extract) to the collector which internally flushes the data in a sorted +/// (transform) manner to files of some specified capacity. +/// +/// The data can later be iterated over (load) in a sorted manner. 
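For orientation, a minimal usage sketch of the collector this crate-level doc describes, mirroring the `etl_hashes` test added later in this file (`TxHash::random()` relies on the same rand support that test already uses):

use std::sync::Arc;

use reth_etl::Collector;
use reth_primitives::{TxHash, TxNumber};
use tempfile::TempDir;

fn main() -> std::io::Result<()> {
    // Flush a sorted run to a temp file whenever the in-memory buffer grows past ~1 KiB.
    let mut collector: Collector<TxHash, TxNumber> =
        Collector::new(Arc::new(TempDir::new()?), 1024);

    // Extract: push unsorted entries; each insert may trigger a sorted flush to disk.
    for number in 0..10_000u64 {
        collector.insert(TxHash::random(), number);
    }

    // Load: iterate all on-disk runs merged into one key-sorted stream of pre-encoded pairs.
    for entry in collector.iter()? {
        let (_encoded_key, _compressed_value) = entry?;
        // Entries are already encoded/compressed, so they must not be re-encoded downstream.
    }
    Ok(())
}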
+#[derive(Debug)] +pub struct Collector +where + K: Encode + Ord, + V: Compress, + ::Encoded: std::fmt::Debug, + ::Compressed: std::fmt::Debug, +{ + /// Directory for temporary file storage + dir: Arc, + /// Collection of temporary ETL files + files: Vec, + /// Current buffer size in bytes + buffer_size_bytes: usize, + /// Maximum buffer capacity in bytes, triggers flush when reached + buffer_capacity_bytes: usize, + /// In-memory buffer storing encoded and compressed key-value pairs + buffer: Vec<(::Encoded, ::Compressed)>, + /// Total number of elements in the collector, including all files + len: usize, +} + +impl Collector +where + K: Key, + V: Value, + ::Encoded: Ord + std::fmt::Debug, + ::Compressed: Ord + std::fmt::Debug, +{ + /// Create a new collector in a specific temporary directory with some capacity. + /// + /// Once the capacity (in bytes) is reached, the data is sorted and flushed to disk. + pub fn new(dir: Arc, buffer_capacity_bytes: usize) -> Self { + Self { + dir, + buffer_size_bytes: 0, + files: Vec::new(), + buffer_capacity_bytes, + buffer: Vec::new(), + len: 0, + } + } + + /// Returns number of elements currently in the collector. + pub fn len(&self) -> usize { + self.len + } + + /// Returns `true` if there are currently no elements in the collector. + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Insert an entry into the collector. + pub fn insert(&mut self, key: K, value: V) { + let key = key.encode(); + let value = value.compress(); + self.buffer_size_bytes += key.as_ref().len() + value.as_ref().len(); + self.buffer.push((key, value)); + if self.buffer_size_bytes > self.buffer_capacity_bytes { + self.flush(); + } + self.len += 1; + } + + fn flush(&mut self) { + self.buffer_size_bytes = 0; + self.buffer.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + let mut buf = Vec::with_capacity(self.buffer.len()); + std::mem::swap(&mut buf, &mut self.buffer); + self.files.push(EtlFile::new(self.dir.path(), buf).expect("could not flush data to disk")) + } + + /// Returns an iterator over the collector data. + /// + /// The items of the iterator are sorted across all underlying files. + /// + /// # Note + /// + /// The keys and values have been pre-encoded, meaning they *SHOULD NOT* be encoded or + /// compressed again. + pub fn iter(&mut self) -> std::io::Result> { + // Flush the remaining items to disk + if self.buffer_size_bytes > 0 { + self.flush(); + } + + let mut heap = BinaryHeap::new(); + for (current_id, file) in self.files.iter_mut().enumerate() { + if let Some((current_key, current_value)) = file.read_next()? { + heap.push((Reverse((current_key, current_value)), current_id)); + } + } + + Ok(EtlIter { heap, files: &mut self.files }) + } +} + +/// `EtlIter` is an iterator for traversing through sorted key-value pairs in a collection of ETL +/// files. These files are created using the [`Collector`] and contain data where keys are encoded +/// and values are compressed. +/// +/// This iterator returns each key-value pair in ascending order based on the key. +/// It is particularly designed to efficiently handle large datasets by employing a binary heap for +/// managing the iteration order. +#[derive(Debug)] +pub struct EtlIter<'a> { + /// Heap managing the next items to be iterated. + #[allow(clippy::type_complexity)] + heap: BinaryHeap<(Reverse<(Vec, Vec)>, usize)>, + /// Reference to the vector of ETL files being iterated over. 
+ files: &'a mut Vec, +} + +impl<'a> EtlIter<'a> { + /// Peeks into the next element + pub fn peek(&self) -> Option<&(Vec, Vec)> { + self.heap.peek().map(|(Reverse(entry), _)| entry) + } +} + +impl<'a> Iterator for EtlIter<'a> { + type Item = std::io::Result<(Vec, Vec)>; + + fn next(&mut self) -> Option { + // Get the next sorted entry from the heap + let (Reverse(entry), id) = self.heap.pop()?; + + // Populate the heap with the next entry from the same file + match self.files[id].read_next() { + Ok(Some((key, value))) => { + self.heap.push((Reverse((key, value)), id)); + Some(Ok(entry)) + } + Ok(None) => Some(Ok(entry)), + err => err.transpose(), + } + } +} + +/// A temporary ETL file. +#[derive(Debug)] +struct EtlFile { + file: BufReader, + len: usize, +} + +impl EtlFile { + /// Create a new file with the given data (which should be pre-sorted) at the given path. + /// + /// The file will be a temporary file. + pub(crate) fn new(dir: &Path, buffer: Vec<(K, V)>) -> std::io::Result + where + Self: Sized, + K: AsRef<[u8]>, + V: AsRef<[u8]>, + { + let file = NamedTempFile::new_in(dir)?; + let mut w = BufWriter::new(file); + for entry in &buffer { + let k = entry.0.as_ref(); + let v = entry.1.as_ref(); + + w.write_all(&k.len().to_be_bytes())?; + w.write_all(&v.len().to_be_bytes())?; + w.write_all(k)?; + w.write_all(v)?; + } + + let mut file = BufReader::new(w.into_inner()?); + file.seek(SeekFrom::Start(0))?; + let len = buffer.len(); + Ok(Self { file, len }) + } + + /// Read the next entry in the file. + /// + /// Can return error if it reaches EOF before filling the internal buffers. + pub(crate) fn read_next(&mut self) -> std::io::Result, Vec)>> { + if self.len == 0 { + return Ok(None); + } + + let mut buffer_key_length = [0; 8]; + let mut buffer_value_length = [0; 8]; + + self.file.read_exact(&mut buffer_key_length)?; + self.file.read_exact(&mut buffer_value_length)?; + + let key_length = usize::from_be_bytes(buffer_key_length); + let value_length = usize::from_be_bytes(buffer_value_length); + let mut key = vec![0; key_length]; + let mut value = vec![0; value_length]; + + self.file.read_exact(&mut key)?; + self.file.read_exact(&mut value)?; + + self.len -= 1; + + Ok(Some((key, value))) + } +} + +#[cfg(test)] +mod tests { + use reth_primitives::{TxHash, TxNumber}; + use tempfile::TempDir; + + use super::*; + + #[test] + fn etl_hashes() { + let mut entries: Vec<_> = + (0..10_000).map(|id| (TxHash::random(), id as TxNumber)).collect(); + + let mut collector = Collector::new(Arc::new(TempDir::new().unwrap()), 1024); + for (k, v) in entries.clone() { + collector.insert(k, v); + } + entries.sort_unstable_by_key(|entry| entry.0); + + for (id, entry) in collector.iter().unwrap().enumerate() { + let expected = entries[id]; + assert_eq!( + entry.unwrap(), + (expected.0.encode().to_vec(), expected.1.compress().to_vec()) + ); + } + } +} diff --git a/crates/node-core/src/node_config.rs b/crates/node-core/src/node_config.rs index 8f77a802a1ae..a6f7f48b4e65 100644 --- a/crates/node-core/src/node_config.rs +++ b/crates/node-core/src/node_config.rs @@ -86,7 +86,7 @@ use reth_stages::{ stages::{ AccountHashingStage, ExecutionStage, ExecutionStageThresholds, IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, StorageHashingStage, - TotalDifficultyStage, TransactionLookupStage, + TransactionLookupStage, }, MetricEvent, }; @@ -893,11 +893,7 @@ impl NodeConfig { header_downloader, body_downloader, factory.clone(), - ) - .set( - TotalDifficultyStage::new(consensus) - 
.with_commit_threshold(stage_config.total_difficulty.commit_threshold), - ) + )? .set(SenderRecoveryStage { commit_threshold: stage_config.sender_recovery.commit_threshold, }) diff --git a/crates/primitives/src/header.rs b/crates/primitives/src/header.rs index 10acec149faa..3c4049edb3d0 100644 --- a/crates/primitives/src/header.rs +++ b/crates/primitives/src/header.rs @@ -570,13 +570,14 @@ impl Decodable for Header { /// A [`Header`] that is sealed at a precalculated hash, use [`SealedHeader::unseal()`] if you want /// to modify header. -#[add_arbitrary_tests(rlp)] -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[main_codec(no_arbitrary)] +#[add_arbitrary_tests(rlp, compact)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SealedHeader { - /// Locked Header fields. - pub header: Header, /// Locked Header hash. pub hash: BlockHash, + /// Locked Header fields. + pub header: Header, } impl SealedHeader { diff --git a/crates/stages/Cargo.toml b/crates/stages/Cargo.toml index 4a3544f9bfe4..79b630972196 100644 --- a/crates/stages/Cargo.toml +++ b/crates/stages/Cargo.toml @@ -26,6 +26,7 @@ reth-codecs.workspace = true reth-provider.workspace = true reth-trie.workspace = true reth-tokio-util.workspace = true +reth-etl.workspace = true # revm revm.workspace = true @@ -42,6 +43,7 @@ tracing.workspace = true # io serde.workspace = true +tempfile.workspace = true # metrics reth-metrics.workspace = true @@ -60,12 +62,12 @@ auto_impl = "1" reth-primitives = { workspace = true, features = ["test-utils", "arbitrary"] } reth-db = { workspace = true, features = ["test-utils", "mdbx"] } reth-interfaces = { workspace = true, features = ["test-utils"] } -reth-provider = { workspace = true, features = ["test-utils"] } reth-downloaders.workspace = true reth-eth-wire.workspace = true # TODO(onbjerg): We only need this for [BlockBody] reth-blockchain-tree.workspace = true reth-revm.workspace = true reth-trie = { workspace = true, features = ["test-utils"] } +reth-provider = { workspace = true, features = ["test-utils"] } alloy-rlp.workspace = true itertools.workspace = true diff --git a/crates/stages/benches/criterion.rs b/crates/stages/benches/criterion.rs index e9354503d279..34b2f9cb6f16 100644 --- a/crates/stages/benches/criterion.rs +++ b/crates/stages/benches/criterion.rs @@ -5,10 +5,10 @@ use criterion::{ }; use pprof::criterion::{Output, PProfProfiler}; use reth_db::{test_utils::TempDatabase, DatabaseEnv}; -use reth_interfaces::test_utils::TestConsensus; + use reth_primitives::stage::StageCheckpoint; use reth_stages::{ - stages::{MerkleStage, SenderRecoveryStage, TotalDifficultyStage, TransactionLookupStage}, + stages::{MerkleStage, SenderRecoveryStage, TransactionLookupStage}, test_utils::TestStageDB, ExecInput, Stage, StageExt, UnwindInput, }; @@ -20,7 +20,7 @@ use setup::StageRange; criterion_group! 
{ name = benches; config = Criterion::default().with_profiler(PProfProfiler::new(1000, Output::Flamegraph(None))); - targets = transaction_lookup, account_hashing, senders, total_difficulty, merkle + targets = transaction_lookup, account_hashing, senders, merkle } criterion_main!(benches); @@ -73,23 +73,6 @@ fn transaction_lookup(c: &mut Criterion) { ); } -fn total_difficulty(c: &mut Criterion) { - let mut group = c.benchmark_group("Stages"); - group.measurement_time(std::time::Duration::from_millis(2000)); - group.warm_up_time(std::time::Duration::from_millis(2000)); - // don't need to run each stage for that many times - group.sample_size(10); - let stage = TotalDifficultyStage::new(Arc::new(TestConsensus::default())); - - measure_stage( - &mut group, - setup::stage_unwind, - stage, - 0..DEFAULT_NUM_BLOCKS, - "TotalDifficulty".to_string(), - ); -} - fn merkle(c: &mut Criterion) { let mut group = c.benchmark_group("Stages"); // don't need to run each stage for that many times diff --git a/crates/stages/src/error.rs b/crates/stages/src/error.rs index d3ecfefb4798..45631b1bed24 100644 --- a/crates/stages/src/error.rs +++ b/crates/stages/src/error.rs @@ -141,6 +141,12 @@ impl StageError { } } +impl From for StageError { + fn from(source: std::io::Error) -> Self { + StageError::Fatal(Box::new(source)) + } +} + /// A pipeline execution error. #[derive(Error, Debug)] pub enum PipelineError { diff --git a/crates/stages/src/lib.rs b/crates/stages/src/lib.rs index 7c8ed234b219..80b40b3ef0da 100644 --- a/crates/stages/src/lib.rs +++ b/crates/stages/src/lib.rs @@ -44,14 +44,17 @@ //! # let pipeline = //! Pipeline::builder() //! .with_tip_sender(tip_tx) -//! .add_stages(DefaultStages::new( -//! provider_factory.clone(), -//! HeaderSyncMode::Tip(tip_rx), -//! consensus, -//! headers_downloader, -//! bodies_downloader, -//! executor_factory, -//! )) +//! .add_stages( +//! DefaultStages::new( +//! provider_factory.clone(), +//! HeaderSyncMode::Tip(tip_rx), +//! consensus, +//! headers_downloader, +//! bodies_downloader, +//! executor_factory, +//! ) +//! .unwrap(), +//! ) //! .build(provider_factory); //! ``` //! diff --git a/crates/stages/src/sets.rs b/crates/stages/src/sets.rs index 1f3d49e390b2..5878abd5c01f 100644 --- a/crates/stages/src/sets.rs +++ b/crates/stages/src/sets.rs @@ -40,9 +40,9 @@ use crate::{ stages::{ AccountHashingStage, BodyStage, ExecutionStage, FinishStage, HeaderStage, IndexAccountHistoryStage, IndexStorageHistoryStage, MerkleStage, SenderRecoveryStage, - StorageHashingStage, TotalDifficultyStage, TransactionLookupStage, + StorageHashingStage, TransactionLookupStage, }, - StageSet, StageSetBuilder, + StageError, StageSet, StageSetBuilder, }; use reth_db::database::Database; use reth_interfaces::{ @@ -51,6 +51,7 @@ use reth_interfaces::{ }; use reth_provider::{ExecutorFactory, HeaderSyncGapProvider, HeaderSyncMode}; use std::sync::Arc; +use tempfile::TempDir; /// A set containing all stages to run a fully syncing instance of reth. 
/// @@ -62,7 +63,6 @@ use std::sync::Arc; /// /// This expands to the following series of stages: /// - [`HeaderStage`] -/// - [`TotalDifficultyStage`] /// - [`BodyStage`] /// - [`SenderRecoveryStage`] /// - [`ExecutionStage`] @@ -91,20 +91,21 @@ impl DefaultStages { header_downloader: H, body_downloader: B, executor_factory: EF, - ) -> Self + ) -> Result where EF: ExecutorFactory, { - Self { + Ok(Self { online: OnlineStages::new( provider, header_mode, consensus, header_downloader, body_downloader, + Arc::new(TempDir::new()?), ), executor_factory, - } + }) } } @@ -150,6 +151,8 @@ pub struct OnlineStages { header_downloader: H, /// The block body downloader body_downloader: B, + /// Temporary directory for ETL usage on headers stage. + temp_dir: Arc, } impl OnlineStages { @@ -160,8 +163,9 @@ impl OnlineStages { consensus: Arc, header_downloader: H, body_downloader: B, + temp_dir: Arc, ) -> Self { - Self { provider, header_mode, consensus, header_downloader, body_downloader } + Self { provider, header_mode, consensus, header_downloader, body_downloader, temp_dir } } } @@ -175,12 +179,8 @@ where pub fn builder_with_headers( headers: HeaderStage, body_downloader: B, - consensus: Arc, ) -> StageSetBuilder { - StageSetBuilder::default() - .add_stage(headers) - .add_stage(TotalDifficultyStage::new(consensus.clone())) - .add_stage(BodyStage::new(body_downloader)) + StageSetBuilder::default().add_stage(headers).add_stage(BodyStage::new(body_downloader)) } /// Create a new builder using the given bodies stage. @@ -190,10 +190,16 @@ where mode: HeaderSyncMode, header_downloader: H, consensus: Arc, + temp_dir: Arc, ) -> StageSetBuilder { StageSetBuilder::default() - .add_stage(HeaderStage::new(provider, header_downloader, mode)) - .add_stage(TotalDifficultyStage::new(consensus.clone())) + .add_stage(HeaderStage::new( + provider, + header_downloader, + mode, + consensus.clone(), + temp_dir.clone(), + )) .add_stage(bodies) } } @@ -207,8 +213,13 @@ where { fn builder(self) -> StageSetBuilder { StageSetBuilder::default() - .add_stage(HeaderStage::new(self.provider, self.header_downloader, self.header_mode)) - .add_stage(TotalDifficultyStage::new(self.consensus.clone())) + .add_stage(HeaderStage::new( + self.provider, + self.header_downloader, + self.header_mode, + self.consensus.clone(), + self.temp_dir.clone(), + )) .add_stage(BodyStage::new(self.body_downloader)) } } diff --git a/crates/stages/src/stages/headers.rs b/crates/stages/src/stages/headers.rs index b57fcd279df9..c200e3a6236b 100644 --- a/crates/stages/src/stages/headers.rs +++ b/crates/stages/src/stages/headers.rs @@ -1,12 +1,16 @@ -use crate::{ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; +use crate::{BlockErrorKind, ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; use futures_util::StreamExt; +use reth_codecs::Compact; use reth_db::{ cursor::{DbCursorRO, DbCursorRW}, database::Database, tables, transaction::{DbTx, DbTxMut}, + RawKey, RawTable, RawValue, }; +use reth_etl::Collector; use reth_interfaces::{ + consensus::Consensus, p2p::headers::{downloader::HeaderDownloader, error::HeadersDownloaderError}, provider::ProviderError, }; @@ -14,10 +18,14 @@ use reth_primitives::{ stage::{ CheckpointBlockRange, EntitiesCheckpoint, HeadersCheckpoint, StageCheckpoint, StageId, }, - BlockHashOrNumber, BlockNumber, SealedHeader, + BlockHash, BlockNumber, SealedHeader, U256, }; use reth_provider::{DatabaseProviderRW, HeaderSyncGap, HeaderSyncGapProvider, HeaderSyncMode}; -use std::task::{ready, Context, 
Poll}; +use std::{ + sync::Arc, + task::{ready, Context, Poll}, +}; +use tempfile::TempDir; use tracing::*; /// The headers stage. @@ -41,10 +49,16 @@ pub struct HeaderStage { downloader: Downloader, /// The sync mode for the stage. mode: HeaderSyncMode, + /// Consensus client implementation + consensus: Arc, /// Current sync gap. sync_gap: Option, - /// Header buffer. - buffer: Option>, + /// ETL collector with HeaderHash -> BlockNumber + hash_collector: Collector, + /// ETL collector with BlockNumber -> SealedHeader + header_collector: Collector, + /// Returns true if the ETL collector has all necessary headers to fill the gap. + is_etl_ready: bool, } // === impl HeaderStage === @@ -54,56 +68,133 @@ where Downloader: HeaderDownloader, { /// Create a new header stage - pub fn new(database: Provider, downloader: Downloader, mode: HeaderSyncMode) -> Self { - Self { provider: database, downloader, mode, sync_gap: None, buffer: None } - } - - fn is_stage_done( - &self, - tx: &::TXMut, - checkpoint: u64, - ) -> Result { - let mut header_cursor = tx.cursor_read::()?; - let (head_num, _) = header_cursor - .seek_exact(checkpoint)? - .ok_or_else(|| ProviderError::HeaderNotFound(checkpoint.into()))?; - // Check if the next entry is congruent - Ok(header_cursor.next()?.map(|(next_num, _)| head_num + 1 == next_num).unwrap_or_default()) + pub fn new( + database: Provider, + downloader: Downloader, + mode: HeaderSyncMode, + consensus: Arc, + tempdir: Arc, + ) -> Self { + Self { + provider: database, + downloader, + mode, + consensus, + sync_gap: None, + hash_collector: Collector::new(tempdir.clone(), 100 * (1024 * 1024)), + header_collector: Collector::new(tempdir, 100 * (1024 * 1024)), + is_etl_ready: false, + } } - /// Write downloaded headers to the given transaction + /// Write downloaded headers to the given transaction from ETL. /// - /// Note: this writes the headers with rising block numbers. + /// Writes to the following tables: + /// [`tables::Headers`], [`tables::CanonicalHeaders`], [`tables::HeaderTD`] and + /// [`tables::HeaderNumbers`]. fn write_headers( - &self, + &mut self, tx: &::TXMut, - headers: Vec, - ) -> Result, StageError> { - trace!(target: "sync::stages::headers", len = headers.len(), "writing headers"); + ) -> Result { + let total_headers = self.header_collector.len(); + + info!(target: "sync::stages::headers", total = total_headers, "Writing headers"); - let mut cursor_header = tx.cursor_write::()?; - let mut cursor_canonical = tx.cursor_write::()?; + let mut cursor_header = tx.cursor_write::>()?; + let mut cursor_canonical = tx.cursor_write::>()?; + let mut cursor_td = tx.cursor_write::()?; - let mut latest = None; - // Since the headers were returned in descending order, - // iterate them in the reverse order - for header in headers.into_iter().rev() { + let mut last_header_number = tx + .cursor_read::()? + .last()? + .map(|(_, header)| header.number) + .unwrap_or_default(); + + // Find the latest total difficulty + let mut td: U256 = cursor_td + .seek_exact(last_header_number)? + .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))? 
+ .1 + .into(); + + // Although headers were downloaded in reverse order, the collector iterates it in ascending + // order + + let interval = total_headers / 10; + for (index, header) in self.header_collector.iter()?.enumerate() { + let (number, header_buf) = header?; + + if index > 0 && index % interval == 0 { + info!(target: "sync::stages::headers", progress = %format!("{:.2}%", (index as f64 / total_headers as f64) * 100.0), "Writing headers"); + } + + let (sealed_header, _) = SealedHeader::from_compact(&header_buf, header_buf.len()); + let (header, header_hash) = sealed_header.split(); if header.number == 0 { continue } + last_header_number = header.number; + + // Increase total difficulty + td += header.difficulty; + + // Header validation + self.consensus.validate_header_with_total_difficulty(&header, td).map_err(|error| { + StageError::Block { + block: Box::new(header.clone().seal(header_hash)), + error: BlockErrorKind::Validation(error), + } + })?; + + // Append to HeaderTD + cursor_td.append(header.number, td.into())?; + + // Append to CanonicalHeaders + cursor_canonical + .append(RawKey::::from_vec(number.clone()), header_hash.into())?; + + // Append to Headers + cursor_header.append(RawKey::::from_vec(number), header.into())?; + } + + info!(target: "sync::stages::headers", total = total_headers, "Writing header hash index"); + + let mut cursor_header_numbers = tx.cursor_write::>()?; + let mut first_sync = false; + + // If we only have the genesis block hash, then we are at first sync, and we can remove it, + // add it to the collector and use tx.append on all hashes. + if let Some((hash, block_number)) = cursor_header_numbers.last()? { + if block_number.value()? == 0 { + self.hash_collector.insert(hash.key()?, 0); + cursor_header_numbers.delete_current()?; + first_sync = true; + } + } - let header_hash = header.hash(); - let header_number = header.number; - let header = header.unseal(); - latest = Some(header.number); + // Since ETL sorts all entries by hashes, we are either appending (first sync) or inserting + // in order (further syncs). + for (index, hash_to_number) in self.hash_collector.iter()?.enumerate() { + let (hash, number) = hash_to_number?; - // NOTE: HeaderNumbers are not sorted and can't be inserted with cursor. - tx.put::(header_hash, header_number)?; - cursor_header.insert(header_number, header)?; - cursor_canonical.insert(header_number, header_hash)?; + if index > 0 && index % interval == 0 { + info!(target: "sync::stages::headers", progress = ((index as f64 / total_headers as f64) * 100.0).round(), "Writing headers hash index"); + } + + if first_sync { + cursor_header_numbers.append( + RawKey::::from_vec(hash), + RawValue::::from_vec(number), + )?; + } else { + cursor_header_numbers.insert( + RawKey::::from_vec(hash), + RawValue::::from_vec(number), + )?; + } } - Ok(latest) + Ok(last_header_number) } } @@ -125,14 +216,8 @@ where ) -> Poll> { let current_checkpoint = input.checkpoint(); - // Return if buffer already has some items. 
- if self.buffer.is_some() { - // TODO: review - trace!( - target: "sync::stages::headers", - checkpoint = %current_checkpoint.block_number, - "Buffer is not empty" - ); + // Return if stage has already completed the gap on the ETL files + if self.is_etl_ready { return Poll::Ready(Ok(())) } @@ -149,27 +234,42 @@ where target = ?tip, "Target block already reached" ); + self.is_etl_ready = true; return Poll::Ready(Ok(())) } debug!(target: "sync::stages::headers", ?tip, head = ?gap.local_head.hash(), "Commencing sync"); + let local_head_number = gap.local_head.number; // let the downloader know what to sync - self.downloader.update_sync_gap(gap.local_head, gap.target); - - let result = match ready!(self.downloader.poll_next_unpin(cx)) { - Some(Ok(headers)) => { - info!(target: "sync::stages::headers", len = headers.len(), "Received headers"); - self.buffer = Some(headers); - Ok(()) - } - Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => { - error!(target: "sync::stages::headers", ?error, "Cannot attach header to head"); - Err(StageError::DetachedHead { local_head, header, error }) + self.downloader.update_sync_gap(gap.local_head, gap.target.clone()); + + // We only want to stop once we have all the headers on ETL filespace (disk). + loop { + match ready!(self.downloader.poll_next_unpin(cx)) { + Some(Ok(headers)) => { + info!(target: "sync::stages::headers", total = headers.len(), from_block = headers.first().map(|h| h.number), to_block = headers.last().map(|h| h.number), "Received headers"); + for header in headers { + let header_number = header.number; + + self.hash_collector.insert(header.hash, header_number); + self.header_collector.insert(header_number, header); + + // Headers are downloaded in reverse, so if we reach here, we know we have + // filled the gap. + if header_number == local_head_number + 1 { + self.is_etl_ready = true; + return Poll::Ready(Ok(())) + } + } + } + Some(Err(HeadersDownloaderError::DetachedHead { local_head, header, error })) => { + error!(target: "sync::stages::headers", ?error, "Cannot attach header to head"); + return Poll::Ready(Err(StageError::DetachedHead { local_head, header, error })) + } + None => return Poll::Ready(Err(StageError::ChannelClosed)), } - None => Err(StageError::ChannelClosed), - }; - Poll::Ready(result) + } } /// Download the headers in reverse order (falling block numbers) @@ -181,99 +281,40 @@ where ) -> Result { let current_checkpoint = input.checkpoint(); - let gap = self.sync_gap.clone().ok_or(StageError::MissingSyncGap)?; - if gap.is_closed() { - return Ok(ExecOutput::done(current_checkpoint)) + if self.sync_gap.as_ref().ok_or(StageError::MissingSyncGap)?.is_closed() { + self.is_etl_ready = false; + return Ok(ExecOutput::done(current_checkpoint)); } - let local_head = gap.local_head.number; - let tip = gap.target.tip(); + // We should be here only after we have downloaded all headers into the disk buffer (ETL). + if !self.is_etl_ready { + return Err(StageError::MissingDownloadBuffer) + } - let downloaded_headers = self.buffer.take().ok_or(StageError::MissingDownloadBuffer)?; - let tip_block_number = match tip { - // If tip is hash and it equals to the first downloaded header's hash, we can use - // the block number of this header as tip. - BlockHashOrNumber::Hash(hash) => downloaded_headers - .first() - .and_then(|header| (header.hash == hash).then_some(header.number)), - // If tip is number, we can just grab it and not resolve using downloaded headers. 
- BlockHashOrNumber::Number(number) => Some(number), - }; + // Reset flag + self.is_etl_ready = false; - // Since we're syncing headers in batches, gap tip will move in reverse direction towards - // our local head with every iteration. To get the actual target block number we're - // syncing towards, we need to take into account already synced headers from the database. - // It is `None`, if tip didn't change and we're still downloading headers for previously - // calculated gap. - let tx = provider.tx_ref(); - let target_block_number = if let Some(tip_block_number) = tip_block_number { - let local_max_block_number = tx - .cursor_read::()? - .last()? - .map(|(canonical_block, _)| canonical_block); - - Some(tip_block_number.max(local_max_block_number.unwrap_or_default())) - } else { - None - }; + // Write the headers and related tables to DB from ETL space + let to_be_processed = self.hash_collector.len() as u64; + let last_header_number = self.write_headers::(provider.tx_ref())?; - let mut stage_checkpoint = match current_checkpoint.headers_stage_checkpoint() { - // If checkpoint block range matches our range, we take the previously used - // stage checkpoint as-is. - Some(stage_checkpoint) - if stage_checkpoint.block_range.from == input.checkpoint().block_number => - { - stage_checkpoint - } - // Otherwise, we're on the first iteration of new gap sync, so we recalculate the number - // of already processed and total headers. - // `target_block_number` is guaranteed to be `Some`, because on the first iteration - // we download the header for missing tip and use its block number. - _ => { - let target = target_block_number.expect("No downloaded header for tip found"); + Ok(ExecOutput { + checkpoint: StageCheckpoint::new(last_header_number).with_headers_stage_checkpoint( HeadersCheckpoint { block_range: CheckpointBlockRange { from: input.checkpoint().block_number, - to: target, + to: last_header_number, }, progress: EntitiesCheckpoint { - // Set processed to the local head block number + number - // of block already filled in the gap. - processed: local_head + (target - tip_block_number.unwrap_or_default()), - total: target, + processed: input.checkpoint().block_number + to_be_processed, + total: last_header_number, }, - } - } - }; - - // Total headers can be updated if we received new tip from the network, and need to fill - // the local gap. - if let Some(target_block_number) = target_block_number { - stage_checkpoint.progress.total = target_block_number; - } - stage_checkpoint.progress.processed += downloaded_headers.len() as u64; - - // Write the headers to db - self.write_headers::(tx, downloaded_headers)?.unwrap_or_default(); - - if self.is_stage_done::(tx, current_checkpoint.block_number)? { - let checkpoint = current_checkpoint.block_number.max( - tx.cursor_read::()? - .last()? - .map(|(num, _)| num) - .unwrap_or_default(), - ); - Ok(ExecOutput { - checkpoint: StageCheckpoint::new(checkpoint) - .with_headers_stage_checkpoint(stage_checkpoint), - done: true, - }) - } else { - Ok(ExecOutput { - checkpoint: current_checkpoint.with_headers_stage_checkpoint(stage_checkpoint), - done: false, - }) - } + }, + ), + // We only reach here if all headers have been downloaded by ETL, and pushed to DB all + // in one stage run. + done: true, + }) } /// Unwind the stage. 
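As background for the `SealedHeader::from_compact` call in `write_headers` above, a small round-trip sketch of the new Compact codec on `SealedHeader` (using `Header::default()` purely as a stand-in value):

use reth_codecs::Compact;
use reth_primitives::{Header, SealedHeader};

fn sealed_header_roundtrip() {
    // `SealedHeader` now derives a Compact codec (with the hash field stored first), so the
    // headers stage can buffer whole sealed headers in the ETL collector as bytes and decode
    // them again on the write pass.
    let sealed: SealedHeader = Header::default().seal_slow();

    let mut buf = Vec::new();
    sealed.clone().to_compact(&mut buf);

    let (decoded, _) = SealedHeader::from_compact(&buf, buf.len());
    assert_eq!(decoded, sealed);
}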
@@ -282,7 +323,6 @@ where provider: &DatabaseProviderRW, input: UnwindInput, ) -> Result { - self.buffer.take(); self.sync_gap.take(); provider.unwind_table_by_walker::( @@ -291,6 +331,8 @@ where provider.unwind_table_by_num::(input.unwind_to)?; let unwound_headers = provider.unwind_table_by_num::(input.unwind_to)?; + provider.unwind_table_by_num::(input.unwind_to)?; + let stage_checkpoint = input.checkpoint.headers_stage_checkpoint().map(|stage_checkpoint| HeadersCheckpoint { block_range: stage_checkpoint.block_range, @@ -338,6 +380,7 @@ mod tests { use reth_primitives::U256; use reth_provider::{BlockHashReader, BlockNumReader, HeaderProvider}; use std::sync::Arc; + use tempfile::TempDir; use tokio::sync::watch; pub(crate) struct HeadersTestRunner { @@ -345,6 +388,7 @@ mod tests { channel: (watch::Sender, watch::Receiver), downloader_factory: Box D + Send + Sync + 'static>, db: TestStageDB, + consensus: Arc, } impl Default for HeadersTestRunner { @@ -353,6 +397,7 @@ mod tests { Self { client: client.clone(), channel: watch::channel(B256::ZERO), + consensus: Arc::new(TestConsensus::default()), downloader_factory: Box::new(move || { TestHeaderDownloader::new( client.clone(), @@ -378,6 +423,8 @@ mod tests { self.db.factory.clone(), (*self.downloader_factory)(), HeaderSyncMode::Tip(self.channel.1.clone()), + self.consensus.clone(), + Arc::new(TempDir::new().unwrap()), ) } } @@ -390,10 +437,7 @@ mod tests { let mut rng = generators::rng(); let start = input.checkpoint().block_number; let head = random_header(&mut rng, start, None); - self.db.insert_headers(std::iter::once(&head))?; - // patch td table for `update_head` call - self.db - .commit(|tx| Ok(tx.put::(head.number, U256::ZERO.into())?))?; + self.db.insert_headers_with_td(std::iter::once(&head))?; // use previous checkpoint as seed size let end = input.target.unwrap_or_default() + 1; @@ -417,8 +461,9 @@ mod tests { match output { Some(output) if output.checkpoint.block_number > initial_checkpoint => { let provider = self.db.factory.provider()?; - for block_num in (initial_checkpoint..output.checkpoint.block_number).rev() - { + let mut td = U256::ZERO; + + for block_num in initial_checkpoint..output.checkpoint.block_number { // look up the header hash let hash = provider.block_hash(block_num)?.expect("no header hash"); @@ -430,6 +475,13 @@ mod tests { assert!(header.is_some()); let header = header.unwrap().seal_slow(); assert_eq!(header.hash(), hash); + + // validate the header total difficulty + td += header.difficulty; + assert_eq!( + provider.header_td_by_number(block_num)?.map(Into::into), + Some(td) + ); } } _ => self.check_no_header_entry_above(initial_checkpoint)?, @@ -469,6 +521,7 @@ mod tests { .build(client.clone(), Arc::new(TestConsensus::default())) }), db: TestStageDB::default(), + consensus: Arc::new(TestConsensus::default()), } } } @@ -482,6 +535,7 @@ mod tests { .ensure_no_entry_above_by_value::(block, |val| val)?; self.db.ensure_no_entry_above::(block, |key| key)?; self.db.ensure_no_entry_above::(block, |key| key)?; + self.db.ensure_no_entry_above::(block, |num| num)?; Ok(()) } @@ -527,69 +581,10 @@ mod tests { }, done: true }) if block_number == tip.number && from == checkpoint && to == previous_stage && // -1 because we don't need to download the local head - processed == checkpoint + headers.len() as u64 - 1 && total == tip.number); - assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); - } - - /// Execute the stage in two steps - #[tokio::test] - async fn 
execute_from_previous_checkpoint() { - let mut runner = HeadersTestRunner::with_linear_downloader(); - // pick range that's larger than the configured headers batch size - let (checkpoint, previous_stage) = (600, 1200); - let mut input = ExecInput { - target: Some(previous_stage), - checkpoint: Some(StageCheckpoint::new(checkpoint)), - }; - let headers = runner.seed_execution(input).expect("failed to seed execution"); - let rx = runner.execute(input); - - runner.client.extend(headers.iter().rev().map(|h| h.clone().unseal())).await; - - // skip `after_execution` hook for linear downloader - let tip = headers.last().unwrap(); - runner.send_tip(tip.hash()); - - let result = rx.await.unwrap(); - assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint { - block_number, - stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint { - block_range: CheckpointBlockRange { - from, - to - }, - progress: EntitiesCheckpoint { - processed, - total, - } - })) - }, done: false }) if block_number == checkpoint && - from == checkpoint && to == previous_stage && - processed == checkpoint + 500 && total == tip.number); - - runner.client.clear().await; - runner.client.extend(headers.iter().take(101).map(|h| h.clone().unseal()).rev()).await; - input.checkpoint = Some(result.unwrap().checkpoint); - - let rx = runner.execute(input); - let result = rx.await.unwrap(); - - assert_matches!(result, Ok(ExecOutput { checkpoint: StageCheckpoint { - block_number, - stage_checkpoint: Some(StageUnitCheckpoint::Headers(HeadersCheckpoint { - block_range: CheckpointBlockRange { - from, - to - }, - progress: EntitiesCheckpoint { - processed, - total, - } - })) - }, done: true }) if block_number == tip.number && - from == checkpoint && to == previous_stage && - // -1 because we don't need to download the local head - processed == checkpoint + headers.len() as u64 - 1 && total == tip.number); + processed == checkpoint + headers.len() as u64 - 1 && total == tip.number + // +1 because of the seeded execution that inserts the first block + && previous_stage - checkpoint + 1 == runner.db().table::().unwrap().len() as u64 + ); assert!(runner.validate_execution(input, result.ok()).is_ok(), "validation failed"); } } diff --git a/crates/stages/src/stages/mod.rs b/crates/stages/src/stages/mod.rs index 29e6a1c63450..ff36d1020700 100644 --- a/crates/stages/src/stages/mod.rs +++ b/crates/stages/src/stages/mod.rs @@ -18,8 +18,6 @@ mod index_storage_history; mod merkle; /// The sender recovery stage. 
mod sender_recovery; -/// The total difficulty stage -mod total_difficulty; /// The transaction lookup stage mod tx_lookup; @@ -33,7 +31,6 @@ pub use index_account_history::*; pub use index_storage_history::*; pub use merkle::*; pub use sender_recovery::*; -pub use total_difficulty::*; pub use tx_lookup::*; #[cfg(test)] diff --git a/crates/stages/src/stages/total_difficulty.rs b/crates/stages/src/stages/total_difficulty.rs deleted file mode 100644 index d523cf4ce850..000000000000 --- a/crates/stages/src/stages/total_difficulty.rs +++ /dev/null @@ -1,313 +0,0 @@ -use crate::{BlockErrorKind, ExecInput, ExecOutput, Stage, StageError, UnwindInput, UnwindOutput}; -use reth_db::{ - cursor::{DbCursorRO, DbCursorRW}, - database::Database, - tables, - transaction::{DbTx, DbTxMut}, - DatabaseError, -}; -use reth_interfaces::{consensus::Consensus, provider::ProviderError}; -use reth_primitives::{ - stage::{EntitiesCheckpoint, StageCheckpoint, StageId}, - U256, -}; -use reth_provider::DatabaseProviderRW; -use std::sync::Arc; -use tracing::*; - -/// The total difficulty stage. -/// -/// This stage walks over inserted headers and computes total difficulty -/// at each block. The entries are inserted into [`HeaderTD`][reth_db::tables::HeaderTD] -/// table. -#[derive(Debug, Clone)] -pub struct TotalDifficultyStage { - /// Consensus client implementation - consensus: Arc, - /// The number of table entries to commit at once - commit_threshold: u64, -} - -impl TotalDifficultyStage { - /// Create a new total difficulty stage - pub fn new(consensus: Arc) -> Self { - Self { consensus, commit_threshold: 100_000 } - } - - /// Set a commit threshold on total difficulty stage - pub fn with_commit_threshold(mut self, commit_threshold: u64) -> Self { - self.commit_threshold = commit_threshold; - self - } -} - -impl Stage for TotalDifficultyStage { - /// Return the id of the stage - fn id(&self) -> StageId { - StageId::TotalDifficulty - } - - /// Write total difficulty entries - fn execute( - &mut self, - provider: &DatabaseProviderRW, - input: ExecInput, - ) -> Result { - let tx = provider.tx_ref(); - if input.target_reached() { - return Ok(ExecOutput::done(input.checkpoint())) - } - - let (range, is_final_range) = input.next_block_range_with_threshold(self.commit_threshold); - let (start_block, end_block) = range.clone().into_inner(); - - debug!(target: "sync::stages::total_difficulty", start_block, end_block, "Commencing sync"); - - // Acquire cursor over total difficulty and headers tables - let mut cursor_td = tx.cursor_write::()?; - let mut cursor_headers = tx.cursor_read::()?; - - // Get latest total difficulty - let last_header_number = input.checkpoint().block_number; - let last_entry = cursor_td - .seek_exact(last_header_number)? - .ok_or(ProviderError::TotalDifficultyNotFound(last_header_number))?; - - let mut td: U256 = last_entry.1.into(); - debug!(target: "sync::stages::total_difficulty", ?td, block_number = last_header_number, "Last total difficulty entry"); - - // Walk over newly inserted headers, update & insert td - for entry in cursor_headers.walk_range(range)? 
{ - let (block_number, header) = entry?; - td += header.difficulty; - - self.consensus.validate_header_with_total_difficulty(&header, td).map_err(|error| { - StageError::Block { - block: Box::new(header.seal_slow()), - error: BlockErrorKind::Validation(error), - } - })?; - cursor_td.append(block_number, td.into())?; - } - - Ok(ExecOutput { - checkpoint: StageCheckpoint::new(end_block) - .with_entities_stage_checkpoint(stage_checkpoint(provider)?), - done: is_final_range, - }) - } - - /// Unwind the stage. - fn unwind( - &mut self, - provider: &DatabaseProviderRW, - input: UnwindInput, - ) -> Result { - let (_, unwind_to, _) = input.unwind_block_range_with_threshold(self.commit_threshold); - - provider.unwind_table_by_num::(unwind_to)?; - - Ok(UnwindOutput { - checkpoint: StageCheckpoint::new(unwind_to) - .with_entities_stage_checkpoint(stage_checkpoint(provider)?), - }) - } -} - -fn stage_checkpoint( - provider: &DatabaseProviderRW, -) -> Result { - Ok(EntitiesCheckpoint { - processed: provider.tx_ref().entries::()? as u64, - total: provider.tx_ref().entries::()? as u64, - }) -} - -#[cfg(test)] -mod tests { - use assert_matches::assert_matches; - use reth_db::transaction::DbTx; - use reth_interfaces::test_utils::{ - generators, - generators::{random_header, random_header_range}, - TestConsensus, - }; - use reth_primitives::{stage::StageUnitCheckpoint, BlockNumber, SealedHeader}; - use reth_provider::HeaderProvider; - - use super::*; - use crate::test_utils::{ - stage_test_suite_ext, ExecuteStageTestRunner, StageTestRunner, TestRunnerError, - TestStageDB, UnwindStageTestRunner, - }; - - stage_test_suite_ext!(TotalDifficultyTestRunner, total_difficulty); - - #[tokio::test] - async fn execute_with_intermediate_commit() { - let threshold = 50; - let (stage_progress, previous_stage) = (1000, 1100); // input exceeds threshold - - let mut runner = TotalDifficultyTestRunner::default(); - runner.set_threshold(threshold); - - let first_input = ExecInput { - target: Some(previous_stage), - checkpoint: Some(StageCheckpoint::new(stage_progress)), - }; - - // Seed only once with full input range - runner.seed_execution(first_input).expect("failed to seed execution"); - - // Execute first time - let result = runner.execute(first_input).await.unwrap(); - let expected_progress = stage_progress + threshold; - assert_matches!( - result, - Ok(ExecOutput { checkpoint: StageCheckpoint { - block_number, - stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint { - processed, - total - })) - }, done: false }) if block_number == expected_progress && processed == 1 + threshold && - total == runner.db.table::().unwrap().len() as u64 - ); - - // Execute second time - let second_input = ExecInput { - target: Some(previous_stage), - checkpoint: Some(StageCheckpoint::new(expected_progress)), - }; - let result = runner.execute(second_input).await.unwrap(); - assert_matches!( - result, - Ok(ExecOutput { checkpoint: StageCheckpoint { - block_number, - stage_checkpoint: Some(StageUnitCheckpoint::Entities(EntitiesCheckpoint { - processed, - total - })) - }, done: true }) if block_number == previous_stage && processed == total && - total == runner.db.table::().unwrap().len() as u64 - ); - - assert!(runner.validate_execution(first_input, result.ok()).is_ok(), "validation failed"); - } - - struct TotalDifficultyTestRunner { - db: TestStageDB, - consensus: Arc, - commit_threshold: u64, - } - - impl Default for TotalDifficultyTestRunner { - fn default() -> Self { - Self { - db: Default::default(), - consensus: 
Arc::new(TestConsensus::default()), - commit_threshold: 500, - } - } - } - - impl StageTestRunner for TotalDifficultyTestRunner { - type S = TotalDifficultyStage; - - fn db(&self) -> &TestStageDB { - &self.db - } - - fn stage(&self) -> Self::S { - TotalDifficultyStage { - consensus: self.consensus.clone(), - commit_threshold: self.commit_threshold, - } - } - } - - #[async_trait::async_trait] - impl ExecuteStageTestRunner for TotalDifficultyTestRunner { - type Seed = Vec; - - fn seed_execution(&mut self, input: ExecInput) -> Result { - let mut rng = generators::rng(); - let start = input.checkpoint().block_number; - let head = random_header(&mut rng, start, None); - self.db.insert_headers(std::iter::once(&head))?; - self.db.commit(|tx| { - let td: U256 = tx - .cursor_read::()? - .last()? - .map(|(_, v)| v) - .unwrap_or_default() - .into(); - tx.put::(head.number, (td + head.difficulty).into())?; - Ok(()) - })?; - - // use previous progress as seed size - let end = input.target.unwrap_or_default() + 1; - - if start + 1 >= end { - return Ok(Vec::default()) - } - - let mut headers = random_header_range(&mut rng, start + 1..end, head.hash()); - self.db.insert_headers(headers.iter())?; - headers.insert(0, head); - Ok(headers) - } - - /// Validate stored headers - fn validate_execution( - &self, - input: ExecInput, - output: Option, - ) -> Result<(), TestRunnerError> { - let initial_stage_progress = input.checkpoint().block_number; - match output { - Some(output) if output.checkpoint.block_number > initial_stage_progress => { - let provider = self.db.factory.provider()?; - - let mut header_cursor = provider.tx_ref().cursor_read::()?; - let (_, mut current_header) = header_cursor - .seek_exact(initial_stage_progress)? - .expect("no initial header"); - let mut td: U256 = provider - .header_td_by_number(initial_stage_progress)? - .expect("no initial td"); - - while let Some((next_key, next_header)) = header_cursor.next()? { - assert_eq!(current_header.number + 1, next_header.number); - td += next_header.difficulty; - assert_eq!( - provider.header_td_by_number(next_key)?.map(Into::into), - Some(td) - ); - current_header = next_header; - } - } - _ => self.check_no_td_above(initial_stage_progress)?, - }; - Ok(()) - } - } - - impl UnwindStageTestRunner for TotalDifficultyTestRunner { - fn validate_unwind(&self, input: UnwindInput) -> Result<(), TestRunnerError> { - self.check_no_td_above(input.unwind_to) - } - } - - impl TotalDifficultyTestRunner { - fn check_no_td_above(&self, block: BlockNumber) -> Result<(), TestRunnerError> { - self.db.ensure_no_entry_above::(block, |num| num)?; - Ok(()) - } - - fn set_threshold(&mut self, new_threshold: u64) { - self.commit_threshold = new_threshold; - } - } -} diff --git a/crates/storage/codecs/derive/src/compact/generator.rs b/crates/storage/codecs/derive/src/compact/generator.rs index 370d74eec2a1..8cd9070bb4b2 100644 --- a/crates/storage/codecs/derive/src/compact/generator.rs +++ b/crates/storage/codecs/derive/src/compact/generator.rs @@ -52,7 +52,7 @@ pub fn generate_from_to(ident: &Ident, fields: &FieldList, is_zstd: bool) -> Tok /// Generates code to implement the `Compact` trait method `to_compact`. fn generate_from_compact(fields: &FieldList, ident: &Ident, is_zstd: bool) -> TokenStream2 { let mut lines = vec![]; - let mut known_types = vec!["B256", "Address", "Bloom", "Vec", "TxHash"]; + let mut known_types = vec!["B256", "Address", "Bloom", "Vec", "TxHash", "BlockHash"]; // Only types without `Bytes` should be added here. 
It's currently manually added, since // it's hard to figure out with derive_macro which types have Bytes fields. diff --git a/crates/storage/codecs/derive/src/compact/mod.rs b/crates/storage/codecs/derive/src/compact/mod.rs index 41b4ccfc55ac..167a7ff6a538 100644 --- a/crates/storage/codecs/derive/src/compact/mod.rs +++ b/crates/storage/codecs/derive/src/compact/mod.rs @@ -143,7 +143,7 @@ fn should_use_alt_impl(ftype: &String, segment: &syn::PathSegment) -> bool { if let (Some(path), 1) = (arg_path.path.segments.first(), arg_path.path.segments.len()) { - if ["B256", "Address", "Address", "Bloom", "TxHash"] + if ["B256", "Address", "Address", "Bloom", "TxHash", "BlockHash"] .contains(&path.ident.to_string().as_str()) { return true diff --git a/crates/storage/db/src/tables/codecs/compact.rs b/crates/storage/db/src/tables/codecs/compact.rs index f31e61026e65..38722eb49036 100644 --- a/crates/storage/db/src/tables/codecs/compact.rs +++ b/crates/storage/db/src/tables/codecs/compact.rs @@ -29,6 +29,7 @@ macro_rules! impl_compression_for_compact { } impl_compression_for_compact!( + SealedHeader, Header, Account, Log, diff --git a/crates/storage/db/src/tables/raw.rs b/crates/storage/db/src/tables/raw.rs index 46dffa1db58f..b1932f152c26 100644 --- a/crates/storage/db/src/tables/raw.rs +++ b/crates/storage/db/src/tables/raw.rs @@ -55,6 +55,12 @@ impl RawKey { Self { key: K::encode(key).into(), _phantom: std::marker::PhantomData } } + /// Creates a raw key from an existing `Vec`. Useful when we already have the encoded + /// key. + pub fn from_vec(vec: Vec) -> Self { + Self { key: vec, _phantom: std::marker::PhantomData } + } + /// Returns the decoded value. pub fn key(&self) -> Result { K::decode(&self.key) @@ -114,6 +120,12 @@ impl RawValue { Self { value: V::compress(value).into(), _phantom: std::marker::PhantomData } } + /// Creates a raw value from an existing `Vec`. Useful when we already have the compressed + /// value. + pub fn from_vec(vec: Vec) -> Self { + Self { value: vec, _phantom: std::marker::PhantomData } + } + /// Returns the decompressed value. pub fn value(&self) -> Result { V::decompress(&self.value)
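To connect the `from_vec` helpers above back to the headers stage, a condensed sketch of the append pattern used in `write_headers` (the function name and cursor parameter are illustrative only):

use reth_db::{cursor::DbCursorRW, tables::HeaderNumbers, RawKey, RawTable, RawValue};
use reth_etl::Collector;
use reth_primitives::{BlockHash, BlockNumber};

// The ETL collector hands back keys and values that are already encoded/compressed, so they
// are wrapped with `RawKey::from_vec` / `RawValue::from_vec` and appended as-is, avoiding a
// second encode/compress pass.
fn append_hash_index(
    cursor: &mut impl DbCursorRW<RawTable<HeaderNumbers>>,
    hash_collector: &mut Collector<BlockHash, BlockNumber>,
) -> Result<(), Box<dyn std::error::Error>> {
    for entry in hash_collector.iter()? {
        let (encoded_hash, compressed_number) = entry?;
        cursor.append(
            RawKey::<BlockHash>::from_vec(encoded_hash),
            RawValue::<BlockNumber>::from_vec(compressed_number),
        )?;
    }
    Ok(())
}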