From 89d9a887fb4d618f79ac5056013ccc5e0e87ad59 Mon Sep 17 00:00:00 2001 From: fselmo Date: Thu, 10 Oct 2024 11:16:52 -0600 Subject: [PATCH] Add support for non-mainnet request cache validation thresholds: - If non-mainnet ethereum, allow the use of an integer value for the request cache validation threshold. Pre-configure "safe" default values for some common non-mainnet chain ids based on their varied finality mechanisms. - This integer value represents the number of seconds from time.now() that the request cache deems as a safe enough time window to allow the request to be cached. - Update the tests to reflect these changes. Bonus: - Add a note to the ``request_mocker`` to make sure that mocked results are of the correct expected types based on JSON-RPC spec (e.g. hex strings instead of ints for numbers, etc.) --- .../caching-utils/test_request_caching.py | 173 +++++++++------ web3/_utils/caching/caching_utils.py | 153 +++++++++++-- .../caching/request_caching_validation.py | 210 ++++++++++++++---- web3/_utils/module_testing/utils.py | 8 + web3/providers/async_base.py | 9 +- web3/providers/base.py | 9 +- 6 files changed, 431 insertions(+), 131 deletions(-) diff --git a/tests/core/caching-utils/test_request_caching.py b/tests/core/caching-utils/test_request_caching.py index c48c03f0a6..abbf917dfe 100644 --- a/tests/core/caching-utils/test_request_caching.py +++ b/tests/core/caching-utils/test_request_caching.py @@ -21,9 +21,11 @@ ) from web3._utils.caching.caching_utils import ( ASYNC_INTERNAL_VALIDATION_MAP, + BLOCK_IN_RESULT, BLOCKHASH_IN_PARAMS, BLOCKNUM_IN_PARAMS, - BLOCKNUM_IN_RESULT, + CHAIN_VALIDATION_THRESHOLD_DEFAULTS, + DEFAULT_VALIDATION_THRESHOLD, INTERNAL_VALIDATION_MAP, ) from web3.exceptions import ( @@ -64,6 +66,7 @@ def w3(request_mocker): mock_results={ "fake_endpoint": lambda *_: uuid.uuid4(), "not_on_allowlist": lambda *_: uuid.uuid4(), + "eth_chainId": "0x1", # mainnet }, ): yield _w3 @@ -134,10 +137,10 @@ def test_caching_requests_does_not_share_state_between_providers(request_mocker) # strap w3_a_shared_cache with w3_a's cache w3_a_shared_cache.provider._request_cache = w3_a.provider._request_cache - mock_results_a = {RPCEndpoint("eth_chainId"): 11111} - mock_results_a_shared_cache = {RPCEndpoint("eth_chainId"): 00000} - mock_results_b = {RPCEndpoint("eth_chainId"): 22222} - mock_results_c = {RPCEndpoint("eth_chainId"): 33333} + mock_results_a = {RPCEndpoint("eth_chainId"): hex(11111)} + mock_results_a_shared_cache = {RPCEndpoint("eth_chainId"): hex(00000)} + mock_results_b = {RPCEndpoint("eth_chainId"): hex(22222)} + mock_results_c = {RPCEndpoint("eth_chainId"): hex(33333)} with request_mocker(w3_a, mock_results=mock_results_a): with request_mocker(w3_b, mock_results=mock_results_b): @@ -154,10 +157,10 @@ def test_caching_requests_does_not_share_state_between_providers(request_mocker) "eth_chainId", [] ) - assert result_a == 11111 - assert result_b == 22222 - assert result_c == 33333 - assert result_a_shared_cache == 11111 + assert result_a == hex(11111) + assert result_b == hex(22222) + assert result_c == hex(33333) + assert result_a_shared_cache == hex(11111) @pytest.mark.parametrize( @@ -199,7 +202,7 @@ def test_all_providers_do_not_cache_by_default_and_can_set_caching_properties(pr "threshold", (RequestCacheValidationThreshold.FINALIZED, RequestCacheValidationThreshold.SAFE), ) -@pytest.mark.parametrize("endpoint", BLOCKNUM_IN_PARAMS | BLOCKNUM_IN_RESULT) +@pytest.mark.parametrize("endpoint", BLOCKNUM_IN_PARAMS | BLOCK_IN_RESULT) @pytest.mark.parametrize( "blocknum,should_cache", ( @@ -211,11 +214,11 @@ def test_all_providers_do_not_cache_by_default_and_can_set_caching_properties(pr ("0x5", False), ), ) -def test_blocknum_validation_against_validation_threshold_when_caching( +def test_blocknum_validation_against_validation_threshold_when_caching_mainnet( threshold, endpoint, blocknum, should_cache, request_mocker ): w3 = Web3( - HTTPProvider( + BaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=threshold ) ) @@ -224,7 +227,7 @@ def test_blocknum_validation_against_validation_threshold_when_caching( mock_results={ endpoint: ( # mock the result to requests that return blocks - {"number": blocknum} + {"number": blocknum, "timestamp": "0x0"} if "getBlock" in endpoint # mock the result to requests that return transactions else {"blockNumber": blocknum} @@ -232,16 +235,17 @@ def test_blocknum_validation_against_validation_threshold_when_caching( "eth_getBlockByNumber": lambda _method, params: ( # mock the threshold block to be blocknum "0x2", return # blocknum otherwise - {"number": "0x2"} + {"number": "0x2", "timestamp": "0x0"} if params[0] == threshold.value - else {"number": params[0]} + else {"number": params[0], "timestamp": "0x0"} ), + "eth_chainId": "0x1", # mainnet }, ): assert len(w3.provider._request_cache.items()) == 0 w3.manager.request_blocking(endpoint, [blocknum, False]) cached_items = len(w3.provider._request_cache.items()) - assert cached_items == 1 if should_cache else cached_items == 0 + assert cached_items > 0 if should_cache else cached_items == 0 @pytest.mark.parametrize( @@ -260,30 +264,31 @@ def test_blocknum_validation_against_validation_threshold_when_caching( ("pending", None, False), ), ) -def test_block_id_param_caching( +def test_block_id_param_caching_mainnet( threshold, endpoint, block_id, blocknum, should_cache, request_mocker ): w3 = Web3( - HTTPProvider( + BaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=threshold ) ) with request_mocker( w3, mock_results={ + "eth_chainId": "0x1", # mainnet endpoint: "0x0", "eth_getBlockByNumber": lambda _method, params: ( # mock the threshold block to be blocknum "0x2" for all test cases - {"number": "0x2"} + {"number": "0x2", "timestamp": "0x0"} if params[0] == threshold.value - else {"number": blocknum} + else {"number": blocknum, "timestamp": "0x0"} ), }, ): assert len(w3.provider._request_cache.items()) == 0 w3.manager.request_blocking(RPCEndpoint(endpoint), [block_id, False]) cached_items = len(w3.provider._request_cache.items()) - assert cached_items == 1 if should_cache else cached_items == 0 + assert cached_items > 0 if should_cache else cached_items == 0 @pytest.mark.parametrize( @@ -302,24 +307,25 @@ def test_block_id_param_caching( ("0x5", False), ), ) -def test_blockhash_validation_against_validation_threshold_when_caching( +def test_blockhash_validation_against_validation_threshold_when_caching_mainnet( threshold, endpoint, blocknum, should_cache, request_mocker ): w3 = Web3( - HTTPProvider( + BaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=threshold ) ) with request_mocker( w3, mock_results={ + "eth_chainId": "0x1", # mainnet "eth_getBlockByNumber": lambda _method, params: ( # mock the threshold block to be blocknum "0x2" - {"number": "0x2"} + {"number": "0x2", "timestamp": "0x0"} if params[0] == threshold.value - else {"number": params[0]} + else {"number": params[0], "timestamp": "0x0"} ), - "eth_getBlockByHash": {"number": blocknum}, + "eth_getBlockByHash": {"number": blocknum, "timestamp": "0x0"}, endpoint: "0x0", }, ): @@ -329,23 +335,37 @@ def test_blockhash_validation_against_validation_threshold_when_caching( assert cached_items == 2 if should_cache else cached_items == 0 -def test_request_caching_validation_threshold_is_finalized_by_default(): - w3 = Web3(HTTPProvider(cache_allowed_requests=True)) - assert ( - w3.provider.request_cache_validation_threshold - == RequestCacheValidationThreshold.FINALIZED - ) +@pytest.mark.parametrize( + "chain_id,expected_threshold", + ( + *CHAIN_VALIDATION_THRESHOLD_DEFAULTS.items(), + (3456787654567654, DEFAULT_VALIDATION_THRESHOLD), + (11111111111444444444444444, DEFAULT_VALIDATION_THRESHOLD), + (-11111111111111111117, DEFAULT_VALIDATION_THRESHOLD), + ), +) +def test_request_caching_validation_threshold_defaults( + chain_id, expected_threshold, request_mocker +): + w3 = Web3(BaseProvider(cache_allowed_requests=True)) + with request_mocker(w3, mock_results={"eth_chainId": hex(chain_id)}): + w3.manager.request_blocking(RPCEndpoint("eth_chainId"), []) + assert w3.provider.request_cache_validation_threshold == expected_threshold + # assert chain_id is cached + cache_items = w3.provider._request_cache.items() + assert len(cache_items) == 1 + assert cache_items[0][1]["result"] == hex(chain_id) @pytest.mark.parametrize( - "endpoint", BLOCKNUM_IN_PARAMS | BLOCKNUM_IN_RESULT | BLOCKHASH_IN_PARAMS + "endpoint", BLOCKNUM_IN_PARAMS | BLOCK_IN_RESULT | BLOCKHASH_IN_PARAMS ) @pytest.mark.parametrize("blocknum", ("0x0", "0x1", "0x2", "0x3", "0x4", "0x5")) def test_request_caching_with_validation_threshold_set_to_none( endpoint, blocknum, request_mocker ): w3 = Web3( - HTTPProvider( + BaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=None, ) @@ -376,6 +396,7 @@ async def async_w3(request_mocker): mock_results={ "fake_endpoint": lambda *_: uuid.uuid4(), "not_on_allowlist": lambda *_: uuid.uuid4(), + "eth_chainId": "0x1", # mainnet }, ): yield _async_w3 @@ -460,10 +481,10 @@ async def test_async_request_caching_does_not_share_state_between_providers( # strap async_w3_a_shared_cache with async_w3_a's cache async_w3_a_shared_cache.provider._request_cache = async_w3_a.provider._request_cache - mock_results_a = {RPCEndpoint("eth_chainId"): 11111} - mock_results_a_shared_cache = {RPCEndpoint("eth_chainId"): 00000} - mock_results_b = {RPCEndpoint("eth_chainId"): 22222} - mock_results_c = {RPCEndpoint("eth_chainId"): 33333} + mock_results_a = {RPCEndpoint("eth_chainId"): hex(11111)} + mock_results_a_shared_cache = {RPCEndpoint("eth_chainId"): hex(00000)} + mock_results_b = {RPCEndpoint("eth_chainId"): hex(22222)} + mock_results_c = {RPCEndpoint("eth_chainId"): hex(33333)} async with request_mocker(async_w3_a, mock_results=mock_results_a): async with request_mocker(async_w3_b, mock_results=mock_results_b): @@ -480,10 +501,10 @@ async def test_async_request_caching_does_not_share_state_between_providers( "eth_chainId", [] ) - assert result_a == 11111 - assert result_b == 22222 - assert result_c == 33333 - assert result_a_shared_cache == 11111 + assert result_a == hex(11111) + assert result_b == hex(22222) + assert result_c == hex(33333) + assert result_a_shared_cache == hex(11111) @pytest.mark.asyncio @@ -491,7 +512,7 @@ async def test_async_request_caching_does_not_share_state_between_providers( "threshold", (RequestCacheValidationThreshold.FINALIZED, RequestCacheValidationThreshold.SAFE), ) -@pytest.mark.parametrize("endpoint", BLOCKNUM_IN_PARAMS | BLOCKNUM_IN_RESULT) +@pytest.mark.parametrize("endpoint", BLOCKNUM_IN_PARAMS | BLOCK_IN_RESULT) @pytest.mark.parametrize( "blocknum,should_cache", ( @@ -503,11 +524,11 @@ async def test_async_request_caching_does_not_share_state_between_providers( ("0x5", False), ), ) -async def test_async_blocknum_validation_against_validation_threshold( +async def test_async_blocknum_validation_against_validation_threshold_mainnet( threshold, endpoint, blocknum, should_cache, request_mocker ): async_w3 = AsyncWeb3( - AsyncHTTPProvider( + AsyncBaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=threshold ) ) @@ -516,7 +537,7 @@ async def test_async_blocknum_validation_against_validation_threshold( mock_results={ endpoint: ( # mock the result to requests that return blocks - {"number": blocknum} + {"number": blocknum, "timestamp": "0x0"} if "getBlock" in endpoint # mock the result to requests that return transactions else {"blockNumber": blocknum} @@ -524,16 +545,17 @@ async def test_async_blocknum_validation_against_validation_threshold( "eth_getBlockByNumber": lambda _method, params: ( # mock the threshold block to be blocknum "0x2", return # blocknum otherwise - {"number": "0x2"} + {"number": "0x2", "timestamp": "0x0"} if params[0] == threshold.value - else {"number": params[0]} + else {"number": params[0], "timestamp": "0x0"} ), + "eth_chainId": "0x1", # mainnet }, ): assert len(async_w3.provider._request_cache.items()) == 0 await async_w3.manager.coro_request(endpoint, [blocknum, False]) cached_items = len(async_w3.provider._request_cache.items()) - assert cached_items == 1 if should_cache else cached_items == 0 + assert cached_items > 0 if should_cache else cached_items == 0 @pytest.mark.asyncio @@ -553,30 +575,31 @@ async def test_async_blocknum_validation_against_validation_threshold( ("pending", None, False), ), ) -async def test_async_block_id_param_caching( +async def test_async_block_id_param_caching_mainnet( threshold, endpoint, block_id, blocknum, should_cache, request_mocker ): async_w3 = AsyncWeb3( - AsyncHTTPProvider( + AsyncBaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=threshold ) ) async with request_mocker( async_w3, mock_results={ + "eth_chainId": "0x1", # mainnet endpoint: "0x0", "eth_getBlockByNumber": lambda _method, params: ( # mock the threshold block to be blocknum "0x2" for all test cases - {"number": "0x2"} + {"number": "0x2", "timestamp": "0x0"} if params[0] == threshold.value - else {"number": blocknum} + else {"number": blocknum, "timestamp": "0x0"} ), }, ): assert len(async_w3.provider._request_cache.items()) == 0 await async_w3.manager.coro_request(RPCEndpoint(endpoint), [block_id, False]) cached_items = len(async_w3.provider._request_cache.items()) - assert cached_items == 1 if should_cache else cached_items == 0 + assert cached_items > 0 if should_cache else cached_items == 0 @pytest.mark.asyncio @@ -596,24 +619,25 @@ async def test_async_block_id_param_caching( ("0x5", False), ), ) -async def test_async_blockhash_validation_against_validation_threshold( +async def test_async_blockhash_validation_against_validation_threshold_mainnet( threshold, endpoint, blocknum, should_cache, request_mocker ): async_w3 = AsyncWeb3( - AsyncHTTPProvider( + AsyncBaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=threshold ) ) async with request_mocker( async_w3, mock_results={ + "eth_chainId": "0x1", # mainnet "eth_getBlockByNumber": lambda _method, params: ( # mock the threshold block to be blocknum "0x2" - {"number": "0x2"} + {"number": "0x2", "timestamp": "0x0"} if params[0] == threshold.value - else {"number": params[0]} + else {"number": params[0], "timestamp": "0x0"} ), - "eth_getBlockByHash": {"number": blocknum}, + "eth_getBlockByHash": {"number": blocknum, "timestamp": "0x0"}, endpoint: "0x0", }, ): @@ -624,24 +648,39 @@ async def test_async_blockhash_validation_against_validation_threshold( @pytest.mark.asyncio -async def test_async_request_caching_validation_threshold_is_finalized_by_default(): - async_w3 = AsyncWeb3(AsyncHTTPProvider(cache_allowed_requests=True)) - assert ( - async_w3.provider.request_cache_validation_threshold - == RequestCacheValidationThreshold.FINALIZED - ) +@pytest.mark.parametrize( + "chain_id,expected_threshold", + ( + *CHAIN_VALIDATION_THRESHOLD_DEFAULTS.items(), + (3456787654567654, DEFAULT_VALIDATION_THRESHOLD), + (11111111111444444444444444, DEFAULT_VALIDATION_THRESHOLD), + (-11111111111111111117, DEFAULT_VALIDATION_THRESHOLD), + ), +) +async def test_async_request_caching_validation_threshold_defaults( + chain_id, expected_threshold, request_mocker +): + async_w3 = AsyncWeb3(AsyncBaseProvider(cache_allowed_requests=True)) + async with request_mocker(async_w3, mock_results={"eth_chainId": hex(chain_id)}): + await async_w3.manager.coro_request(RPCEndpoint("eth_chainId"), []) + assert ( + async_w3.provider.request_cache_validation_threshold == expected_threshold + ) + cache_items = async_w3.provider._request_cache.items() + assert len(cache_items) == 1 + assert cache_items[0][1]["result"] == hex(chain_id) @pytest.mark.asyncio @pytest.mark.parametrize( - "endpoint", BLOCKNUM_IN_PARAMS | BLOCKNUM_IN_RESULT | BLOCKHASH_IN_PARAMS + "endpoint", BLOCKNUM_IN_PARAMS | BLOCK_IN_RESULT | BLOCKHASH_IN_PARAMS ) @pytest.mark.parametrize("blocknum", ("0x0", "0x1", "0x2", "0x3", "0x4", "0x5")) async def test_async_request_caching_with_validation_threshold_set_to_none( endpoint, blocknum, request_mocker ): async_w3 = AsyncWeb3( - AsyncHTTPProvider( + AsyncBaseProvider( cache_allowed_requests=True, request_cache_validation_threshold=None, ) diff --git a/web3/_utils/caching/caching_utils.py b/web3/_utils/caching/caching_utils.py index dc640cee07..1fb0bc5ec4 100644 --- a/web3/_utils/caching/caching_utils.py +++ b/web3/_utils/caching/caching_utils.py @@ -16,6 +16,9 @@ Union, ) +from eth_typing import ( + ChainId, +) from eth_utils import ( is_boolean, is_bytes, @@ -34,12 +37,15 @@ from web3._utils.caching.request_caching_validation import ( UNCACHEABLE_BLOCK_IDS, always_cache_request, - async_validate_blockhash_in_params, - async_validate_blocknum_in_params, - async_validate_blocknum_in_result, - validate_blockhash_in_params, - validate_blocknum_in_params, - validate_blocknum_in_result, + async_validate_from_block_id_in_params, + async_validate_from_blockhash_in_params, + async_validate_from_blocknum_in_result, + validate_from_block_id_in_params, + validate_from_blockhash_in_params, + validate_from_blocknum_in_result, +) +from web3._utils.empty import ( + empty, ) from web3._utils.rpc_abi import ( RPC, @@ -47,6 +53,9 @@ from web3.exceptions import ( Web3TypeError, ) +from web3.utils import ( + RequestCacheValidationThreshold, +) if TYPE_CHECKING: from web3.providers import ( # noqa: F401 @@ -100,6 +109,28 @@ def __init__( self.middleware_response_processors: List[Callable[..., Any]] = [] +DEFAULT_VALIDATION_THRESHOLD = 60 * 60 # 1 hour + +CHAIN_VALIDATION_THRESHOLD_DEFAULTS: Dict[ + int, Union[RequestCacheValidationThreshold, int] +] = { + # Suggested safe values as defaults for each chain. Users can configure a different + # value if desired. + ChainId.ETH.value: RequestCacheValidationThreshold.FINALIZED, + ChainId.ARB1.value: 7 * 24 * 60 * 60, # 7 days + ChainId.ZKSYNC.value: 60 * 60, # 1 hour + ChainId.OETH.value: 3 * 60, # 3 minutes + ChainId.MATIC.value: 30 * 60, # 30 minutes + ChainId.ZKEVM.value: 60 * 60, # 1 hour + ChainId.BASE.value: 7 * 24 * 60 * 60, # 7 days + ChainId.SCR.value: 60 * 60, # 1 hour + ChainId.GNO.value: 5 * 60, # 5 minutes + ChainId.AVAX.value: 2 * 60, # 2 minutes + ChainId.BNB.value: 2 * 60, # 2 minutes + ChainId.FTM.value: 60, # 1 minute +} + + def is_cacheable_request( provider: Union[ASYNC_PROVIDER_TYPE, SYNC_PROVIDER_TYPE], method: "RPCEndpoint", @@ -128,7 +159,7 @@ def is_cacheable_request( RPC.eth_getUncleByBlockNumberAndIndex, RPC.eth_getUncleCountByBlockNumber, } -BLOCKNUM_IN_RESULT = { +BLOCK_IN_RESULT = { RPC.eth_getBlockByHash, RPC.eth_getTransactionByHash, RPC.eth_getTransactionByBlockNumberAndIndex, @@ -142,16 +173,60 @@ def is_cacheable_request( } INTERNAL_VALIDATION_MAP: Dict[ - "RPCEndpoint", Callable[[SYNC_PROVIDER_TYPE, Sequence[Any], Dict[str, Any]], bool] + "RPCEndpoint", + Callable[ + [SYNC_PROVIDER_TYPE, Sequence[Any], Dict[str, Any]], + bool, + ], ] = { **{endpoint: always_cache_request for endpoint in ALWAYS_CACHE}, - **{endpoint: validate_blocknum_in_params for endpoint in BLOCKNUM_IN_PARAMS}, - **{endpoint: validate_blocknum_in_result for endpoint in BLOCKNUM_IN_RESULT}, - **{endpoint: validate_blockhash_in_params for endpoint in BLOCKHASH_IN_PARAMS}, + **{endpoint: validate_from_block_id_in_params for endpoint in BLOCKNUM_IN_PARAMS}, + **{endpoint: validate_from_blocknum_in_result for endpoint in BLOCK_IN_RESULT}, + **{endpoint: validate_from_blockhash_in_params for endpoint in BLOCKHASH_IN_PARAMS}, } CACHEABLE_REQUESTS = tuple(INTERNAL_VALIDATION_MAP.keys()) +def set_threshold_if_empty(provider: SYNC_PROVIDER_TYPE) -> None: + current_threshold = provider.request_cache_validation_threshold + + if current_threshold is empty or isinstance( + current_threshold, RequestCacheValidationThreshold + ): + try: + # turn off momentarily to avoid recursion + provider.cache_allowed_requests = False + chain_id_result = provider.make_request("eth_chainId", [])["result"] + chain_id = int(chain_id_result, 16) + + if ( + isinstance( + current_threshold, + RequestCacheValidationThreshold, + ) + and chain_id != 1 + ): + provider.logger.debug( + "Request cache validation threshold is set to " + f"{current_threshold.value} " + f"for chain with chain_id `{chain_id}` but this value only works " + "on chain_id `1`. Setting to default value for chain_id " + f"`{chain_id}`.", + ) + provider.request_cache_validation_threshold = empty + + if current_threshold is empty: + provider.request_cache_validation_threshold = ( + CHAIN_VALIDATION_THRESHOLD_DEFAULTS.get( + chain_id, DEFAULT_VALIDATION_THRESHOLD + ) + ) + except Exception: + provider.request_cache_validation_threshold = DEFAULT_VALIDATION_THRESHOLD + finally: + provider.cache_allowed_requests = True + + def _should_cache_response( provider: SYNC_PROVIDER_TYPE, method: "RPCEndpoint", @@ -161,6 +236,8 @@ def _should_cache_response( result = response.get("result", None) if "error" in response or is_null(result): return False + + set_threshold_if_empty(provider) if ( method in INTERNAL_VALIDATION_MAP and provider.request_cache_validation_threshold is not None @@ -206,14 +283,60 @@ def wrapper( ASYNC_INTERNAL_VALIDATION_MAP: Dict["RPCEndpoint", ASYNC_VALIDATOR_TYPE] = { **{endpoint: always_cache_request for endpoint in ALWAYS_CACHE}, - **{endpoint: async_validate_blocknum_in_params for endpoint in BLOCKNUM_IN_PARAMS}, - **{endpoint: async_validate_blocknum_in_result for endpoint in BLOCKNUM_IN_RESULT}, **{ - endpoint: async_validate_blockhash_in_params for endpoint in BLOCKHASH_IN_PARAMS + endpoint: async_validate_from_block_id_in_params + for endpoint in BLOCKNUM_IN_PARAMS + }, + **{ + endpoint: async_validate_from_blocknum_in_result for endpoint in BLOCK_IN_RESULT + }, + **{ + endpoint: async_validate_from_blockhash_in_params + for endpoint in BLOCKHASH_IN_PARAMS }, } +async def async_set_threshold_if_empty(provider: ASYNC_PROVIDER_TYPE) -> None: + current_threshold = provider.request_cache_validation_threshold + + if current_threshold is empty or isinstance( + current_threshold, RequestCacheValidationThreshold + ): + try: + # turn off momentarily to avoid recursion + provider.cache_allowed_requests = False + chain_id_result = await provider.make_request("eth_chainId", []) + chain_id = int(chain_id_result["result"], 16) + + if ( + isinstance( + current_threshold, + RequestCacheValidationThreshold, + ) + and chain_id != 1 + ): + provider.logger.debug( + "Request cache validation threshold is set to " + f"{current_threshold.value} " + f"for chain with chain_id `{chain_id}` but this value only works " + "on chain_id `1`. Setting to default value for chain_id " + f"`{chain_id}`.", + ) + provider.request_cache_validation_threshold = empty + + if current_threshold is empty: + provider.request_cache_validation_threshold = ( + CHAIN_VALIDATION_THRESHOLD_DEFAULTS.get( + chain_id, DEFAULT_VALIDATION_THRESHOLD + ) + ) + except Exception: + provider.request_cache_validation_threshold = DEFAULT_VALIDATION_THRESHOLD + finally: + provider.cache_allowed_requests = True + + async def _async_should_cache_response( provider: ASYNC_PROVIDER_TYPE, method: "RPCEndpoint", @@ -223,6 +346,8 @@ async def _async_should_cache_response( result = response.get("result", None) if "error" in response or is_null(result): return False + + await async_set_threshold_if_empty(provider) if ( method in ASYNC_INTERNAL_VALIDATION_MAP and provider.request_cache_validation_threshold is not None diff --git a/web3/_utils/caching/request_caching_validation.py b/web3/_utils/caching/request_caching_validation.py index c5f963a3ed..b82a0d15fe 100644 --- a/web3/_utils/caching/request_caching_validation.py +++ b/web3/_utils/caching/request_caching_validation.py @@ -1,3 +1,4 @@ +import time from typing import ( TYPE_CHECKING, Any, @@ -7,8 +8,8 @@ Union, ) -from eth_utils import ( - to_int, +from web3.utils import ( + RequestCacheValidationThreshold, ) if TYPE_CHECKING: @@ -31,99 +32,216 @@ def _error_log( ) -def is_beyond_validation_threshold(provider: SYNC_PROVIDER_TYPE, blocknum: int) -> bool: +def always_cache_request(*_args: Any, **_kwargs: Any) -> bool: + return True + + +def is_beyond_validation_threshold( + provider: SYNC_PROVIDER_TYPE, + blocknum: int = None, + block_timestamp: int = None, +) -> bool: try: - # `threshold` is either "finalized" or "safe" - threshold = provider.request_cache_validation_threshold.value - response = provider.make_request("eth_getBlockByNumber", [threshold, False]) - return blocknum <= to_int(hexstr=response["result"]["number"]) + threshold = provider.request_cache_validation_threshold + + if isinstance(threshold, RequestCacheValidationThreshold): + # if mainnet and threshold is "finalized" or "safe" + threshold_block = provider.make_request( + "eth_getBlockByNumber", [threshold.value, False] + )["result"] + # we should have a `blocknum` to compare against + return blocknum <= int(threshold_block["number"], 16) + elif isinstance(threshold, int): + if not block_timestamp: + # if validating via `blocknum` from params, we need to get the timestamp + # for the block with `blocknum`. + block = provider.make_request( + "eth_getBlockByNumber", [hex(blocknum), False] + )["result"] + block_timestamp = int(block["timestamp"], 16) + + # if validating via `block_timestamp` from result, we should have a + # `block_timestamp` to compare against + return block_timestamp <= time.time() - threshold + else: + provider.logger.error( + "Invalid request_cache_validation_threshold value. This should not " + f"have happened. Request not cached.\n threshold: {threshold}" + ) + return False except Exception as e: _error_log(provider, e) return False -def always_cache_request(*_args: Any, **_kwargs: Any) -> bool: - return True - - -def validate_blocknum_in_params( - provider: SYNC_PROVIDER_TYPE, params: Sequence[Any], _result: Dict[str, Any] +def validate_from_block_id_in_params( + provider: SYNC_PROVIDER_TYPE, + params: Sequence[Any], + _result: Dict[str, Any], ) -> bool: block_id = params[0] if block_id == "earliest": # `earliest` should always be cacheable return True - blocknum = to_int(hexstr=block_id) - return is_beyond_validation_threshold(provider, blocknum) + + blocknum = int(block_id, 16) + return is_beyond_validation_threshold(provider, blocknum=blocknum) -def validate_blocknum_in_result( - provider: SYNC_PROVIDER_TYPE, _params: Sequence[Any], result: Dict[str, Any] +def validate_from_blocknum_in_result( + provider: SYNC_PROVIDER_TYPE, + _params: Sequence[Any], + result: Dict[str, Any], ) -> bool: - # `number` if block result, `blockNumber` if transaction result - blocknum = to_int(hexstr=result.get("number", result.get("blockNumber"))) - return is_beyond_validation_threshold(provider, blocknum) + try: + # transaction results + if "blockNumber" in result: + blocknum = result.get("blockNumber") + # make an extra call to get the block values + block = provider.make_request("eth_getBlockByNumber", [blocknum, False])[ + "result" + ] + return is_beyond_validation_threshold( + provider, + blocknum=int(blocknum, 16), + block_timestamp=int(block["timestamp"], 16), + ) + elif "number" in result: + return is_beyond_validation_threshold( + provider, + blocknum=int(result["number"], 16), + block_timestamp=int(result["timestamp"], 16), + ) + else: + provider.logger.error( + "Could not find block number in result. This should not have happened. " + f"Request not cached.\n result: {result}", + ) + return False + except Exception as e: + _error_log(provider, e) + return False -def validate_blockhash_in_params( - provider: SYNC_PROVIDER_TYPE, params: Sequence[Any], _result: Dict[str, Any] +def validate_from_blockhash_in_params( + provider: SYNC_PROVIDER_TYPE, + params: Sequence[Any], + _result: Dict[str, Any], ) -> bool: try: # make an extra call to get the block number from the hash - response = provider.make_request("eth_getBlockByHash", [params[0], False]) + block = provider.make_request("eth_getBlockByHash", [params[0], False])[ + "result" + ] + return is_beyond_validation_threshold( + provider, + blocknum=int(block["number"], 16), + block_timestamp=int(block["timestamp"], 16), + ) except Exception as e: _error_log(provider, e) return False - blocknum = to_int(hexstr=response["result"]["number"]) - return is_beyond_validation_threshold(provider, blocknum) - # -- async -- # async def async_is_beyond_validation_threshold( - provider: ASYNC_PROVIDER_TYPE, blocknum: int + provider: ASYNC_PROVIDER_TYPE, + blocknum: int = None, + block_timestamp: int = None, ) -> bool: try: - # `threshold` is either "finalized" or "safe" - threshold = provider.request_cache_validation_threshold.value - response = await provider.make_request( - "eth_getBlockByNumber", [threshold, False] - ) - return blocknum <= to_int(hexstr=response["result"]["number"]) + threshold = provider.request_cache_validation_threshold + + if isinstance(threshold, RequestCacheValidationThreshold): + # if mainnet and threshold is "finalized" or "safe" + threshold_block = await provider.make_request( + "eth_getBlockByNumber", [threshold.value, False] + ) + # we should have a `blocknum` to compare against + return blocknum <= int(threshold_block["result"]["number"], 16) + elif isinstance(threshold, int): + if not block_timestamp: + block = await provider.make_request( + "eth_getBlockByNumber", [hex(blocknum), False] + ) + block_timestamp = int(block["result"]["timestamp"], 16) + + # if validating via `block_timestamp` from result, we should have a + # `block_timestamp` to compare against + return block_timestamp <= time.time() - threshold + else: + provider.logger.error( + "Invalid request_cache_validation_threshold value. This should not " + f"have happened. Request not cached.\n threshold: {threshold}" + ) + return False except Exception as e: _error_log(provider, e) return False -async def async_validate_blocknum_in_params( - provider: ASYNC_PROVIDER_TYPE, params: Sequence[Any], _result: Dict[str, Any] +async def async_validate_from_block_id_in_params( + provider: ASYNC_PROVIDER_TYPE, + params: Sequence[Any], + _result: Dict[str, Any], ) -> bool: block_id = params[0] if block_id == "earliest": # `earliest` should always be cacheable return True - blocknum = to_int(hexstr=params[0]) - return await async_is_beyond_validation_threshold(provider, blocknum) + + blocknum = int(block_id, 16) + return await async_is_beyond_validation_threshold(provider, blocknum=blocknum) -async def async_validate_blocknum_in_result( - provider: ASYNC_PROVIDER_TYPE, _params: Sequence[Any], result: Dict[str, Any] +async def async_validate_from_blocknum_in_result( + provider: ASYNC_PROVIDER_TYPE, + _params: Sequence[Any], + result: Dict[str, Any], ) -> bool: - # `number` if block result, `blockNumber` if transaction result - blocknum = to_int(hexstr=result.get("number", result.get("blockNumber"))) - return await async_is_beyond_validation_threshold(provider, blocknum) + try: + # transaction results + if "blockNumber" in result: + blocknum = result.get("blockNumber") + # make an extra call to get the block values + block = await provider.make_request( + "eth_getBlockByNumber", [blocknum, False] + ) + return await async_is_beyond_validation_threshold( + provider, + blocknum=int(blocknum, 16), + block_timestamp=int(block["result"]["timestamp"], 16), + ) + elif "number" in result: + return await async_is_beyond_validation_threshold( + provider, + blocknum=int(result["number"], 16), + block_timestamp=int(result["timestamp"], 16), + ) + else: + provider.logger.error( + "Could not find block number in result. This should not have happened. " + f"Request not cached.\n result: {result}", + ) + return False + except Exception as e: + _error_log(provider, e) + return False -async def async_validate_blockhash_in_params( +async def async_validate_from_blockhash_in_params( provider: ASYNC_PROVIDER_TYPE, params: Sequence[Any], _result: Dict[str, Any] ) -> bool: try: + # make an extra call to get the block number from the hash response = await provider.make_request("eth_getBlockByHash", [params[0], False]) + return await async_is_beyond_validation_threshold( + provider, + blocknum=int(response["result"]["number"], 16), + block_timestamp=int(response["result"]["timestamp"], 16), + ) except Exception as e: _error_log(provider, e) return False - - blocknum = to_int(hexstr=response["result"]["number"]) - return await async_is_beyond_validation_threshold(provider, blocknum) diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py index b2a4fd5084..01260bb5f3 100644 --- a/web3/_utils/module_testing/utils.py +++ b/web3/_utils/module_testing/utils.py @@ -35,6 +35,14 @@ class RequestMocker: Context manager to mock requests made by a web3 instance. This is meant to be used via a ``request_mocker`` fixture defined within the appropriate context. + ************************************************************************************ + Important: When mocking results, it's important to keep in mind the types that + clients return. For example, what we commonly translate to integers are returned + as hex strings in the RPC response and should be mocked as such for more + accurate testing. + ************************************************************************************ + + Example: ------- diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 5a4f3f415f..c09756c83a 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -10,6 +10,7 @@ Optional, Set, Tuple, + Union, cast, ) @@ -23,6 +24,10 @@ CACHEABLE_REQUESTS, async_handle_request_caching, ) +from web3._utils.empty import ( + Empty, + empty, +) from web3._utils.encoding import ( FriendlyJsonSerde, Web3JsonEncoder, @@ -88,8 +93,8 @@ def __init__( cache_allowed_requests: bool = False, cacheable_requests: Set[RPCEndpoint] = None, request_cache_validation_threshold: Optional[ - RequestCacheValidationThreshold - ] = RequestCacheValidationThreshold.FINALIZED, + Union[RequestCacheValidationThreshold, int, Empty] + ] = empty, ) -> None: self._request_cache = SimpleCache(1000) self.cache_allowed_requests = cache_allowed_requests diff --git a/web3/providers/base.py b/web3/providers/base.py index b653f267bc..d165c8a61a 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -9,6 +9,7 @@ Optional, Set, Tuple, + Union, cast, ) @@ -21,6 +22,10 @@ CACHEABLE_REQUESTS, handle_request_caching, ) +from web3._utils.empty import ( + Empty, + empty, +) from web3._utils.encoding import ( FriendlyJsonSerde, Web3JsonEncoder, @@ -71,8 +76,8 @@ def __init__( cache_allowed_requests: bool = False, cacheable_requests: Set[RPCEndpoint] = None, request_cache_validation_threshold: Optional[ - RequestCacheValidationThreshold - ] = RequestCacheValidationThreshold.FINALIZED, + Union[RequestCacheValidationThreshold, int, Empty] + ] = empty, ) -> None: self._request_cache = SimpleCache(1000) self.cache_allowed_requests = cache_allowed_requests