diff --git a/crates/anvil/src/cmd.rs b/crates/anvil/src/cmd.rs index 533272f44675..96c4d86de815 100644 --- a/crates/anvil/src/cmd.rs +++ b/crates/anvil/src/cmd.rs @@ -188,6 +188,7 @@ impl NodeArgs { .fork_block_number .or_else(|| self.evm_opts.fork_url.as_ref().and_then(|f| f.block)), ) + .with_fork_headers(self.evm_opts.fork_headers) .with_fork_chain_id(self.evm_opts.fork_chain_id.map(u64::from)) .fork_request_timeout(self.evm_opts.fork_request_timeout.map(Duration::from_millis)) .fork_request_retries(self.evm_opts.fork_request_retries) @@ -328,6 +329,17 @@ pub struct AnvilEvmArgs { )] pub fork_url: Option, + /// Headers to use for the rpc client, e.g. "User-Agent: test-agent" + /// + /// See --fork-url. + #[clap( + long = "fork-header", + value_name = "HEADERS", + help_heading = "Fork config", + requires = "fork_url" + )] + pub fork_headers: Vec, + /// Timeout in ms for requests sent to remote JSON-RPC server in forking mode. /// /// Default value 45000 @@ -664,6 +676,23 @@ mod tests { assert_eq!(args.hardfork, Some(Hardfork::Berlin)); } + #[test] + fn can_parse_fork_headers() { + let args: NodeArgs = NodeArgs::parse_from([ + "anvil", + "--fork-url", + "http,://localhost:8545", + "--fork-header", + "User-Agent: test-agent", + "--fork-header", + "Referrer: example.com", + ]); + assert_eq!( + args.evm_opts.fork_headers, + vec!["User-Agent: test-agent", "Referrer: example.com"] + ); + } + #[test] fn can_parse_prune_config() { let args: NodeArgs = NodeArgs::parse_from(["anvil", "--prune-history"]); diff --git a/crates/anvil/src/config.rs b/crates/anvil/src/config.rs index b9901d469ba2..4a38e291f5f8 100644 --- a/crates/anvil/src/config.rs +++ b/crates/anvil/src/config.rs @@ -123,6 +123,8 @@ pub struct NodeConfig { pub eth_rpc_url: Option, /// pins the block number for the state fork pub fork_block_number: Option, + /// headers to use with `eth_rpc_url` + pub fork_headers: Vec, /// specifies chain id for cache to skip fetching from remote in offline-start mode pub fork_chain_id: Option, /// The generator used to generate the dev accounts @@ -392,6 +394,7 @@ impl Default for NodeConfig { config_out: None, genesis: None, fork_request_timeout: REQUEST_TIMEOUT, + fork_headers: vec![], fork_request_retries: 5, fork_retry_backoff: Duration::from_millis(1_000), fork_chain_id: None, @@ -659,6 +662,13 @@ impl NodeConfig { self } + /// Sets the `fork_headers` to use with `eth_rpc_url` + #[must_use] + pub fn with_fork_headers(mut self, headers: Vec) -> Self { + self.fork_headers = headers; + self + } + /// Sets the `fork_request_timeout` to use for requests #[must_use] pub fn fork_request_timeout(mut self, fork_request_timeout: Option) -> Self { @@ -901,6 +911,7 @@ impl NodeConfig { .compute_units_per_second(self.compute_units_per_second) .max_retry(10) .initial_backoff(1000) + .headers(self.fork_headers.clone()) .build() .expect("Failed to establish provider to fork url"), ); diff --git a/crates/common/src/provider.rs b/crates/common/src/provider.rs index ecb313f4aa70..f00734721e45 100644 --- a/crates/common/src/provider.rs +++ b/crates/common/src/provider.rs @@ -1,6 +1,9 @@ //! Commonly used helpers to construct `Provider`s -use crate::{runtime_client::RuntimeClient, ALCHEMY_FREE_TIER_CUPS, REQUEST_TIMEOUT}; +use crate::{ + runtime_client::{RuntimeClient, RuntimeClientBuilder}, + ALCHEMY_FREE_TIER_CUPS, REQUEST_TIMEOUT, +}; use ethers_core::types::{Chain, U256}; use ethers_middleware::gas_oracle::{GasCategory, GasOracle, Polygon}; use ethers_providers::{is_local_endpoint, Middleware, Provider, DEFAULT_LOCAL_POLL_INTERVAL}; @@ -61,6 +64,7 @@ pub struct ProviderBuilder { compute_units_per_second: u64, /// JWT Secret jwt: Option, + headers: Vec, } // === impl ProviderBuilder === @@ -104,6 +108,7 @@ impl ProviderBuilder { // alchemy max cpus compute_units_per_second: ALCHEMY_FREE_TIER_CUPS, jwt: None, + headers: vec![], } } @@ -188,6 +193,13 @@ impl ProviderBuilder { self } + /// Sets http headers + pub fn headers(mut self, headers: Vec) -> Self { + self.headers = headers; + + self + } + /// Same as [`Self:build()`] but also retrieves the `chainId` in order to derive an appropriate /// interval. pub async fn connect(self) -> Result { @@ -211,20 +223,24 @@ impl ProviderBuilder { timeout, compute_units_per_second, jwt, + headers, } = self; let url = url?; - let is_local = is_local_endpoint(url.as_str()); - - let mut provider = Provider::new(RuntimeClient::new( - url, + let client_builder = RuntimeClientBuilder::new( + url.clone(), max_retry, timeout_retry, initial_backoff, timeout, compute_units_per_second, - jwt, - )); + ) + .with_headers(headers) + .with_jwt(jwt); + + let mut provider = Provider::new(client_builder.build()); + + let is_local = is_local_endpoint(url.as_str()); if is_local { provider = provider.interval(DEFAULT_LOCAL_POLL_INTERVAL); diff --git a/crates/common/src/runtime_client.rs b/crates/common/src/runtime_client.rs index d09f17d638c2..38fd9ddf9d0c 100644 --- a/crates/common/src/runtime_client.rs +++ b/crates/common/src/runtime_client.rs @@ -7,9 +7,12 @@ use ethers_providers::{ JsonRpcError, JwtAuth, JwtKey, ProviderError, PubsubClient, RetryClient, RetryClientBuilder, RpcError, Ws, }; -use reqwest::{header::HeaderValue, Url}; +use reqwest::{ + header::{HeaderName, HeaderValue}, + Url, +}; use serde::{de::DeserializeOwned, Serialize}; -use std::{fmt::Debug, path::PathBuf, sync::Arc, time::Duration}; +use std::{fmt::Debug, path::PathBuf, str::FromStr, sync::Arc, time::Duration}; use thiserror::Error; use tokio::sync::RwLock; @@ -39,6 +42,10 @@ pub enum RuntimeClientError { #[error("URL scheme is not supported: {0}")] BadScheme(String), + /// Invalid HTTP header + #[error("Invalid HTTP header: {0}")] + BadHeader(String), + /// Invalid file path #[error("Invalid IPC file path: {0}")] BadPath(String), @@ -82,6 +89,20 @@ pub struct RuntimeClient { /// available CUPS compute_units_per_second: u64, jwt: Option, + headers: Vec, +} + +/// Builder for RuntimeClient +pub struct RuntimeClientBuilder { + url: Url, + max_retry: u32, + timeout_retry: u32, + initial_backoff: u64, + timeout: Duration, + /// available CUPS + compute_units_per_second: u64, + jwt: Option, + headers: Vec, } impl ::core::fmt::Display for RuntimeClient { @@ -104,32 +125,11 @@ fn build_auth(jwt: String) -> eyre::Result { } impl RuntimeClient { - /// Creates a new dynamic provider from a URL - pub fn new( - url: Url, - max_retry: u32, - timeout_retry: u32, - initial_backoff: u64, - timeout: Duration, - compute_units_per_second: u64, - jwt: Option, - ) -> Self { - Self { - client: Arc::new(RwLock::new(None)), - url, - max_retry, - timeout_retry, - initial_backoff, - timeout, - compute_units_per_second, - jwt, - } - } - async fn connect(&self) -> Result { match self.url.scheme() { "http" | "https" => { let mut client_builder = reqwest::Client::builder().timeout(self.timeout); + let mut headers = reqwest::header::HeaderMap::new(); if let Some(jwt) = self.jwt.as_ref() { let auth = build_auth(jwt.clone()).map_err(|err| { @@ -142,16 +142,25 @@ impl RuntimeClient { .expect("Header should be valid string"); auth_value.set_sensitive(true); - let mut headers = reqwest::header::HeaderMap::new(); headers.insert(reqwest::header::AUTHORIZATION, auth_value); - - client_builder = client_builder.default_headers(headers); }; + for header in self.headers.iter() { + let make_err = || RuntimeClientError::BadHeader(header.to_string()); + + let (key, val) = header.split_once(':').ok_or_else(make_err)?; + + headers.insert( + HeaderName::from_str(key.trim()).map_err(|_| make_err())?, + HeaderValue::from_str(val.trim()).map_err(|_| make_err())?, + ); + } + + client_builder = client_builder.default_headers(headers); + let client = client_builder .build() .map_err(|e| RuntimeClientError::ProviderError(e.into()))?; - let provider = Http::new_with_client(self.url.clone(), client); #[allow(clippy::box_default)] @@ -190,6 +199,57 @@ impl RuntimeClient { } } +impl RuntimeClientBuilder { + /// Create new RuntimeClientBuilder + pub fn new( + url: Url, + max_retry: u32, + timeout_retry: u32, + initial_backoff: u64, + timeout: Duration, + compute_units_per_second: u64, + ) -> Self { + Self { + url, + max_retry, + timeout, + timeout_retry, + initial_backoff, + compute_units_per_second, + jwt: None, + headers: vec![], + } + } + + /// Set jwt to use with RuntimeClient + pub fn with_jwt(mut self, jwt: Option) -> Self { + self.jwt = jwt; + self + } + + /// Set http headers to use with RuntimeClient + /// Only works with http/https schemas + pub fn with_headers(mut self, headers: Vec) -> Self { + self.headers = headers; + self + } + + /// Builds RuntimeClient instance + pub fn build(self) -> RuntimeClient { + RuntimeClient { + client: Arc::new(RwLock::new(None)), + url: self.url, + max_retry: self.max_retry, + timeout_retry: self.timeout_retry, + initial_backoff: self.initial_backoff, + timeout: self.timeout, + compute_units_per_second: self.compute_units_per_second, + jwt: self.jwt, + headers: self.headers, + } + } +} + #[cfg(windows)] fn url_to_file_path(url: &Url) -> Result { const PREFIX: &str = "file:///pipe/";