diff --git a/src/lib.rs b/src/lib.rs index d3b97ba92..afacfd1ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ pub mod providers { pub mod debug_provider; pub mod eth_provider; pub mod pool_provider; + pub mod sn_provider; } pub mod client; pub mod config; diff --git a/src/pool/mempool.rs b/src/pool/mempool.rs index 492e1e1f4..ce411a0b6 100644 --- a/src/pool/mempool.rs +++ b/src/pool/mempool.rs @@ -1,14 +1,5 @@ use super::validate::KakarotTransactionValidator; -use crate::{ - client::EthClient, - into_via_wrapper, - models::felt::Felt252Wrapper, - providers::eth_provider::{ - error::ExecutionError, - starknet::{ERC20Reader, STARKNET_NATIVE_TOKEN}, - utils::{class_hash_not_declared, contract_not_found}, - }, -}; +use crate::{client::EthClient, providers::sn_provider::StarknetProvider}; use reth_primitives::{BlockId, U256}; use reth_transaction_pool::{ blobstore::NoopBlobStore, CoinbaseTipOrdering, EthPooledTransaction, Pool, TransactionPool, @@ -17,7 +8,6 @@ use serde_json::Value; use starknet::core::types::Felt; use std::{collections::HashMap, fs::File, io::Read, sync::Arc, time::Duration}; use tokio::{runtime::Handle, sync::Mutex}; -use tracing::Instrument; /// A type alias for the Kakarot Transaction Validator. /// Uses the Reth implementation [`TransactionValidationTaskExecutor`]. @@ -106,7 +96,7 @@ impl AccountM for account_address in account_addresses { // Fetch the balance and handle errors functionally let balance = self - .get_balance(&account_address) + .get_balance(account_address) .await .inspect_err(|err| { tracing::error!( @@ -132,32 +122,13 @@ impl AccountM } /// Retrieves the balance of the specified account address. - async fn get_balance(&self, account_address: &Felt) -> eyre::Result { + async fn get_balance(&self, account_address: Felt) -> eyre::Result { // Convert the optional Ethereum block ID to a Starknet block ID. let starknet_block_id = self.eth_client.eth_provider().to_starknet_block_id(Some(BlockId::default())).await?; - - // Create a new `ERC20Reader` instance for the Starknet native token - let eth_contract = ERC20Reader::new(*STARKNET_NATIVE_TOKEN, self.eth_client.eth_provider().starknet_provider()); - - // Call the `balanceOf` method on the contract for the given account_address and block ID, awaiting the result - let span = tracing::span!(tracing::Level::INFO, "sn::balance"); - let res = eth_contract.balanceOf(account_address).block_id(starknet_block_id).call().instrument(span).await; - - if contract_not_found(&res) || class_hash_not_declared(&res) { - return Err(eyre::eyre!("Contract not found or class hash not declared")); - } - - // Otherwise, extract the balance from the result, converting any errors to ExecutionError - let balance = res.map_err(ExecutionError::from)?.balance; - - // Convert the low and high parts of the balance to U256 - let low: U256 = into_via_wrapper!(balance.low); - let high: U256 = into_via_wrapper!(balance.high); - - // Combine the low and high parts to form the final balance and return it - let balance = low + (high << 128); - - Ok(balance) + // Create a new Starknet provider wrapper. + let starknet_provider = StarknetProvider::new(Arc::new(self.eth_client.eth_provider().starknet_provider())); + // Get the balance of the address at the given block ID. + starknet_provider.balance_at(account_address, starknet_block_id).await.map_err(Into::into) } /// Processes a transaction for the given account if the balance is sufficient. diff --git a/src/providers/eth_provider/state.rs b/src/providers/eth_provider/state.rs index 20cb1095b..8f565abde 100644 --- a/src/providers/eth_provider/state.rs +++ b/src/providers/eth_provider/state.rs @@ -1,18 +1,20 @@ +use std::sync::Arc; + use super::{ database::state::{EthCacheDatabase, EthDatabase}, error::{EthApiError, ExecutionError, TransactionError}, - starknet::{ - kakarot_core::{account_contract::AccountContractReader, starknet_address}, - ERC20Reader, STARKNET_NATIVE_TOKEN, - }, - utils::{class_hash_not_declared, contract_not_found, entrypoint_not_found, split_u256}, + starknet::kakarot_core::{account_contract::AccountContractReader, starknet_address}, + utils::{contract_not_found, entrypoint_not_found, split_u256}, }; use crate::{ into_via_wrapper, models::felt::Felt252Wrapper, - providers::eth_provider::{ - provider::{EthApiResult, EthDataProvider}, - BlockProvider, ChainProvider, + providers::{ + eth_provider::{ + provider::{EthApiResult, EthDataProvider}, + BlockProvider, ChainProvider, + }, + sn_provider::StarknetProvider, }, }; use async_trait::async_trait; @@ -70,35 +72,10 @@ where async fn balance(&self, address: Address, block_id: Option) -> EthApiResult { // Convert the optional Ethereum block ID to a Starknet block ID. let starknet_block_id = self.to_starknet_block_id(block_id).await?; - - // Create a new `ERC20Reader` instance for the Starknet native token - let eth_contract = ERC20Reader::new(*STARKNET_NATIVE_TOKEN, self.starknet_provider()); - - // Call the `balanceOf` method on the contract for the given address and block ID, awaiting the result - let span = tracing::span!(tracing::Level::INFO, "sn::balance"); - let res = eth_contract - .balanceOf(&starknet_address(address)) - .block_id(starknet_block_id) - .call() - .instrument(span) - .await; - - // Check if the contract was not found or the class hash not declared, - // returning a default balance of 0 if true. - // The native token contract should be deployed on Kakarot, so this should not happen - // We want to avoid errors in this case and return a default balance of 0 - if contract_not_found(&res) || class_hash_not_declared(&res) { - return Ok(Default::default()); - } - // Otherwise, extract the balance from the result, converting any errors to ExecutionError - let balance = res.map_err(ExecutionError::from)?.balance; - - // Convert the low and high parts of the balance to U256 - let low: U256 = into_via_wrapper!(balance.low); - let high: U256 = into_via_wrapper!(balance.high); - - // Combine the low and high parts to form the final balance and return it - Ok(low + (high << 128)) + // Create a new Starknet provider wrapper. + let starknet_provider = StarknetProvider::new(Arc::new(self.starknet_provider())); + // Get the balance of the address at the given block ID. + starknet_provider.balance_at(starknet_address(address), starknet_block_id).await.map_err(Into::into) } async fn storage_at( diff --git a/src/providers/sn_provider/mod.rs b/src/providers/sn_provider/mod.rs new file mode 100644 index 000000000..1b3e6d24a --- /dev/null +++ b/src/providers/sn_provider/mod.rs @@ -0,0 +1,3 @@ +pub mod starknet_provider; + +pub use starknet_provider::StarknetProvider; diff --git a/src/providers/sn_provider/starknet_provider.rs b/src/providers/sn_provider/starknet_provider.rs new file mode 100644 index 000000000..cc961e9fe --- /dev/null +++ b/src/providers/sn_provider/starknet_provider.rs @@ -0,0 +1,63 @@ +use crate::{ + into_via_wrapper, + models::felt::Felt252Wrapper, + providers::eth_provider::{ + error::ExecutionError, + starknet::{ERC20Reader, STARKNET_NATIVE_TOKEN}, + utils::{class_hash_not_declared, contract_not_found}, + }, +}; +use reth_primitives::U256; +use starknet::core::types::{BlockId, Felt}; +use std::sync::Arc; +use tracing::Instrument; + +/// A provider wrapper around the Starknet provider to expose utility methods. +#[derive(Debug, Clone)] +pub struct StarknetProvider { + /// The underlying Starknet provider wrapped in an [`Arc`] for shared ownership across threads. + provider: Arc, +} + +impl StarknetProvider +where + SP: starknet::providers::Provider + Send + Sync, +{ + /// Creates a new [`StarknetProvider`] instance from an [`Arc`]-wrapped Starknet provider. + pub const fn new(provider: Arc) -> Self { + Self { provider } + } + + /// Retrieves the balance of a Starknet address for a specified block. + /// + /// This method interacts with the Starknet native token contract to query the balance of the given + /// address at a specific block. + /// + /// If the contract is not deployed or the class hash is not declared, a balance of 0 is returned + /// instead of an error. + pub async fn balance_at(&self, address: Felt, block_id: BlockId) -> Result { + // Create a new `ERC20Reader` instance for the Starknet native token + let eth_contract = ERC20Reader::new(*STARKNET_NATIVE_TOKEN, &self.provider); + + // Call the `balanceOf` method on the contract for the given address and block ID, awaiting the result + let span = tracing::span!(tracing::Level::INFO, "sn::balance"); + let res = eth_contract.balanceOf(&address).block_id(block_id).call().instrument(span).await; + + // Check if the contract was not found or the class hash not declared, + // returning a default balance of 0 if true. + // The native token contract should be deployed on Kakarot, so this should not happen + // We want to avoid errors in this case and return a default balance of 0 + if contract_not_found(&res) || class_hash_not_declared(&res) { + return Ok(Default::default()); + } + // Otherwise, extract the balance from the result, converting any errors to ExecutionError + let balance = res.map_err(ExecutionError::from)?.balance; + + // Convert the low and high parts of the balance to U256 + let low: U256 = into_via_wrapper!(balance.low); + let high: U256 = into_via_wrapper!(balance.high); + + // Combine the low and high parts to form the final balance and return it + Ok(low + (high << 128)) + } +}