Skip to content

Commit

Permalink
feat(anvil): allow pass headers to fork-url (#6178)
Browse files Browse the repository at this point in the history
* feat(anvil): allow pass headers to `fork-url`

* fix(clippy)

* chore: formatting

* fix: don't use expect

* touchups

---------

Co-authored-by: Matthias Seitz <matthias.seitz@outlook.de>
  • Loading branch information
vbrvk and mattsse committed Nov 3, 2023
1 parent 265059b commit 1c7bf46
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 35 deletions.
29 changes: 29 additions & 0 deletions crates/anvil/src/cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ impl NodeArgs {
.fork_block_number
.or_else(|| self.evm_opts.fork_url.as_ref().and_then(|f| f.block)),
)
.with_fork_headers(self.evm_opts.fork_headers)
.with_fork_chain_id(self.evm_opts.fork_chain_id.map(u64::from))
.fork_request_timeout(self.evm_opts.fork_request_timeout.map(Duration::from_millis))
.fork_request_retries(self.evm_opts.fork_request_retries)
Expand Down Expand Up @@ -328,6 +329,17 @@ pub struct AnvilEvmArgs {
)]
pub fork_url: Option<ForkUrl>,

/// Headers to use for the rpc client, e.g. "User-Agent: test-agent"
///
/// See --fork-url.
#[clap(
long = "fork-header",
value_name = "HEADERS",
help_heading = "Fork config",
requires = "fork_url"
)]
pub fork_headers: Vec<String>,

/// Timeout in ms for requests sent to remote JSON-RPC server in forking mode.
///
/// Default value 45000
Expand Down Expand Up @@ -664,6 +676,23 @@ mod tests {
assert_eq!(args.hardfork, Some(Hardfork::Berlin));
}

#[test]
fn can_parse_fork_headers() {
let args: NodeArgs = NodeArgs::parse_from([
"anvil",
"--fork-url",
"http,://localhost:8545",
"--fork-header",
"User-Agent: test-agent",
"--fork-header",
"Referrer: example.com",
]);
assert_eq!(
args.evm_opts.fork_headers,
vec!["User-Agent: test-agent", "Referrer: example.com"]
);
}

#[test]
fn can_parse_prune_config() {
let args: NodeArgs = NodeArgs::parse_from(["anvil", "--prune-history"]);
Expand Down
11 changes: 11 additions & 0 deletions crates/anvil/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ pub struct NodeConfig {
pub eth_rpc_url: Option<String>,
/// pins the block number for the state fork
pub fork_block_number: Option<u64>,
/// headers to use with `eth_rpc_url`
pub fork_headers: Vec<String>,
/// specifies chain id for cache to skip fetching from remote in offline-start mode
pub fork_chain_id: Option<U256>,
/// The generator used to generate the dev accounts
Expand Down Expand Up @@ -392,6 +394,7 @@ impl Default for NodeConfig {
config_out: None,
genesis: None,
fork_request_timeout: REQUEST_TIMEOUT,
fork_headers: vec![],
fork_request_retries: 5,
fork_retry_backoff: Duration::from_millis(1_000),
fork_chain_id: None,
Expand Down Expand Up @@ -659,6 +662,13 @@ impl NodeConfig {
self
}

/// Sets the `fork_headers` to use with `eth_rpc_url`
#[must_use]
pub fn with_fork_headers(mut self, headers: Vec<String>) -> Self {
self.fork_headers = headers;
self
}

/// Sets the `fork_request_timeout` to use for requests
#[must_use]
pub fn fork_request_timeout(mut self, fork_request_timeout: Option<Duration>) -> Self {
Expand Down Expand Up @@ -901,6 +911,7 @@ impl NodeConfig {
.compute_units_per_second(self.compute_units_per_second)
.max_retry(10)
.initial_backoff(1000)
.headers(self.fork_headers.clone())
.build()
.expect("Failed to establish provider to fork url"),
);
Expand Down
30 changes: 23 additions & 7 deletions crates/common/src/provider.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Commonly used helpers to construct `Provider`s

use crate::{runtime_client::RuntimeClient, ALCHEMY_FREE_TIER_CUPS, REQUEST_TIMEOUT};
use crate::{
runtime_client::{RuntimeClient, RuntimeClientBuilder},
ALCHEMY_FREE_TIER_CUPS, REQUEST_TIMEOUT,
};
use ethers_core::types::{Chain, U256};
use ethers_middleware::gas_oracle::{GasCategory, GasOracle, Polygon};
use ethers_providers::{is_local_endpoint, Middleware, Provider, DEFAULT_LOCAL_POLL_INTERVAL};
Expand Down Expand Up @@ -61,6 +64,7 @@ pub struct ProviderBuilder {
compute_units_per_second: u64,
/// JWT Secret
jwt: Option<String>,
headers: Vec<String>,
}

// === impl ProviderBuilder ===
Expand Down Expand Up @@ -104,6 +108,7 @@ impl ProviderBuilder {
// alchemy max cpus <https://github.com/alchemyplatform/alchemy-docs/blob/master/documentation/compute-units.md#rate-limits-cups>
compute_units_per_second: ALCHEMY_FREE_TIER_CUPS,
jwt: None,
headers: vec![],
}
}

Expand Down Expand Up @@ -188,6 +193,13 @@ impl ProviderBuilder {
self
}

/// Sets http headers
pub fn headers(mut self, headers: Vec<String>) -> Self {
self.headers = headers;

self
}

/// Same as [`Self:build()`] but also retrieves the `chainId` in order to derive an appropriate
/// interval.
pub async fn connect(self) -> Result<RetryProvider> {
Expand All @@ -211,20 +223,24 @@ impl ProviderBuilder {
timeout,
compute_units_per_second,
jwt,
headers,
} = self;
let url = url?;

let is_local = is_local_endpoint(url.as_str());

let mut provider = Provider::new(RuntimeClient::new(
url,
let client_builder = RuntimeClientBuilder::new(
url.clone(),
max_retry,
timeout_retry,
initial_backoff,
timeout,
compute_units_per_second,
jwt,
));
)
.with_headers(headers)
.with_jwt(jwt);

let mut provider = Provider::new(client_builder.build());

let is_local = is_local_endpoint(url.as_str());

if is_local {
provider = provider.interval(DEFAULT_LOCAL_POLL_INTERVAL);
Expand Down
116 changes: 88 additions & 28 deletions crates/common/src/runtime_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ use ethers_providers::{
JsonRpcError, JwtAuth, JwtKey, ProviderError, PubsubClient, RetryClient, RetryClientBuilder,
RpcError, Ws,
};
use reqwest::{header::HeaderValue, Url};
use reqwest::{
header::{HeaderName, HeaderValue},
Url,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{fmt::Debug, path::PathBuf, sync::Arc, time::Duration};
use std::{fmt::Debug, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::sync::RwLock;

Expand Down Expand Up @@ -39,6 +42,10 @@ pub enum RuntimeClientError {
#[error("URL scheme is not supported: {0}")]
BadScheme(String),

/// Invalid HTTP header
#[error("Invalid HTTP header: {0}")]
BadHeader(String),

/// Invalid file path
#[error("Invalid IPC file path: {0}")]
BadPath(String),
Expand Down Expand Up @@ -82,6 +89,20 @@ pub struct RuntimeClient {
/// available CUPS
compute_units_per_second: u64,
jwt: Option<String>,
headers: Vec<String>,
}

/// Builder for RuntimeClient
pub struct RuntimeClientBuilder {
url: Url,
max_retry: u32,
timeout_retry: u32,
initial_backoff: u64,
timeout: Duration,
/// available CUPS
compute_units_per_second: u64,
jwt: Option<String>,
headers: Vec<String>,
}

impl ::core::fmt::Display for RuntimeClient {
Expand All @@ -104,32 +125,11 @@ fn build_auth(jwt: String) -> eyre::Result<Authorization> {
}

impl RuntimeClient {
/// Creates a new dynamic provider from a URL
pub fn new(
url: Url,
max_retry: u32,
timeout_retry: u32,
initial_backoff: u64,
timeout: Duration,
compute_units_per_second: u64,
jwt: Option<String>,
) -> Self {
Self {
client: Arc::new(RwLock::new(None)),
url,
max_retry,
timeout_retry,
initial_backoff,
timeout,
compute_units_per_second,
jwt,
}
}

async fn connect(&self) -> Result<InnerClient, RuntimeClientError> {
match self.url.scheme() {
"http" | "https" => {
let mut client_builder = reqwest::Client::builder().timeout(self.timeout);
let mut headers = reqwest::header::HeaderMap::new();

if let Some(jwt) = self.jwt.as_ref() {
let auth = build_auth(jwt.clone()).map_err(|err| {
Expand All @@ -142,16 +142,25 @@ impl RuntimeClient {
.expect("Header should be valid string");
auth_value.set_sensitive(true);

let mut headers = reqwest::header::HeaderMap::new();
headers.insert(reqwest::header::AUTHORIZATION, auth_value);

client_builder = client_builder.default_headers(headers);
};

for header in self.headers.iter() {
let make_err = || RuntimeClientError::BadHeader(header.to_string());

let (key, val) = header.split_once(':').ok_or_else(make_err)?;

headers.insert(
HeaderName::from_str(key.trim()).map_err(|_| make_err())?,
HeaderValue::from_str(val.trim()).map_err(|_| make_err())?,
);
}

client_builder = client_builder.default_headers(headers);

let client = client_builder
.build()
.map_err(|e| RuntimeClientError::ProviderError(e.into()))?;

let provider = Http::new_with_client(self.url.clone(), client);

#[allow(clippy::box_default)]
Expand Down Expand Up @@ -190,6 +199,57 @@ impl RuntimeClient {
}
}

impl RuntimeClientBuilder {
/// Create new RuntimeClientBuilder
pub fn new(
url: Url,
max_retry: u32,
timeout_retry: u32,
initial_backoff: u64,
timeout: Duration,
compute_units_per_second: u64,
) -> Self {
Self {
url,
max_retry,
timeout,
timeout_retry,
initial_backoff,
compute_units_per_second,
jwt: None,
headers: vec![],
}
}

/// Set jwt to use with RuntimeClient
pub fn with_jwt(mut self, jwt: Option<String>) -> Self {
self.jwt = jwt;
self
}

/// Set http headers to use with RuntimeClient
/// Only works with http/https schemas
pub fn with_headers(mut self, headers: Vec<String>) -> Self {
self.headers = headers;
self
}

/// Builds RuntimeClient instance
pub fn build(self) -> RuntimeClient {
RuntimeClient {
client: Arc::new(RwLock::new(None)),
url: self.url,
max_retry: self.max_retry,
timeout_retry: self.timeout_retry,
initial_backoff: self.initial_backoff,
timeout: self.timeout,
compute_units_per_second: self.compute_units_per_second,
jwt: self.jwt,
headers: self.headers,
}
}
}

#[cfg(windows)]
fn url_to_file_path(url: &Url) -> Result<PathBuf, ()> {
const PREFIX: &str = "file:///pipe/";
Expand Down

0 comments on commit 1c7bf46

Please sign in to comment.