Skip to content

Commit

Permalink
Cache get_decimals for every oracle (#345)
Browse files Browse the repository at this point in the history
- Share token decimal cache for every oracle
- Related to #343
  • Loading branch information
Uxio0 authored Sep 5, 2022
1 parent 9302d5e commit ed83837
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 57 deletions.
52 changes: 30 additions & 22 deletions gnosis/eth/oracles/oracles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import requests
from eth_abi.exceptions import DecodingError
Expand Down Expand Up @@ -77,6 +77,31 @@ def get_underlying_tokens(self, *args) -> List[Tuple[UnderlyingToken]]:
pass


@functools.lru_cache(maxsize=10_000)
def get_decimals(
token_address: ChecksumAddress, ethereum_client: EthereumClient
) -> int:
"""
Auxiliary function so RPC call to get `token decimals` is cached and
can be reused for every Oracle, instead of having a cache per Oracle.
:param token_address:
:param ethereum_client:
:return: Decimals for a token
:raises CannotGetPriceFromOracle: If there's a problem with the query
"""
try:
return ethereum_client.erc20.get_decimals(token_address)
except (
ValueError,
BadFunctionCallOutput,
DecodingError,
) as e:
error_message = f"Cannot get decimals for token={token_address}"
logger.warning(error_message)
raise CannotGetPriceFromOracle(error_message) from e


class KyberOracle(PriceOracle):
# This is the `tokenAddress` they use for ETH ¯\_(ツ)_/¯
ETH_TOKEN_ADDRESS = "0xEeeeeEeeeEeEeeEeEeEeeEEEeeeeEeeeeeeeEEeE"
Expand Down Expand Up @@ -320,7 +345,6 @@ def __init__(
self.router = get_uniswap_v2_router_contract(
ethereum_client.w3, self.router_address
)
self._decimals_cache: Dict[str, int] = {}

@cached_property
def factory(self):
Expand Down Expand Up @@ -388,24 +412,6 @@ def calculate_pair_address(self, token_address: str, token_address_2: str):
)[-20:]
return fast_bytes_to_checksum_address(address)

def get_decimals(self, token_address: str, token_address_2: str) -> Tuple[int, int]:
if not (
token_address in self._decimals_cache
and token_address_2 in self._decimals_cache
):
decimals_1, decimals_2 = self.ethereum_client.batch_call(
[
get_erc20_contract(self.w3, token_address).functions.decimals(),
get_erc20_contract(self.w3, token_address_2).functions.decimals(),
]
)
self._decimals_cache[token_address] = decimals_1
self._decimals_cache[token_address_2] = decimals_2
return (
self._decimals_cache[token_address],
self._decimals_cache[token_address_2],
)

def get_reserves(self, pair_address: str) -> Tuple[int, int]:
"""
Returns the number of tokens in the pool. `getReserves()` also returns the block.timestamp (mod 2**32) of
Expand Down Expand Up @@ -439,7 +445,8 @@ def get_price(
pair_address = self.calculate_pair_address(token_address, token_address_2)
# Tokens are sorted, so token_1 < token_2
reserves_1, reserves_2 = self.get_reserves(pair_address)
decimals_1, decimals_2 = self.get_decimals(token_address, token_address_2)
decimals_1 = get_decimals(token_address, self.ethereum_client)
decimals_2 = get_decimals(token_address_2, self.ethereum_client)
if token_address.lower() > token_address_2.lower():
reserves_2, reserves_1 = reserves_1, reserves_2

Expand Down Expand Up @@ -504,7 +511,8 @@ def get_pool_token_price(self, pool_token_address: ChecksumAddress) -> float:
pair_contract.functions.totalSupply(),
]
)
decimals_1, decimals_2 = self.get_decimals(token_address_1, token_address_2)
decimals_1 = get_decimals(token_address_1, self.ethereum_client)
decimals_2 = get_decimals(token_address_2, self.ethereum_client)

# Total value for one token should be the same than total value for the other token
# if pool is under active arbitrage. We use the price for the first token we find
Expand Down
22 changes: 3 additions & 19 deletions gnosis/eth/oracles/uniswap_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@

from .. import EthereumClient
from ..constants import NULL_ADDRESS
from ..contracts import get_erc20_contract
from .abis.uniswap_v3 import (
uniswap_v3_factory_abi,
uniswap_v3_pool_abi,
uniswap_v3_router_abi,
)
from .oracles import CannotGetPriceFromOracle, PriceOracle
from .oracles import CannotGetPriceFromOracle, PriceOracle, get_decimals

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,21 +57,6 @@ def is_available(
uniswap_v3_router_address or cls.UNISWAP_V3_ROUTER
)

@functools.lru_cache(maxsize=5120)
def get_decimals(self, token_address: str) -> int:
try:
return (
get_erc20_contract(self.w3, token_address).functions.decimals().call()
)
except (
ValueError,
BadFunctionCallOutput,
DecodingError,
) as e:
error_message = f"Cannot get decimals for token={token_address}"
logger.warning(error_message)
raise CannotGetPriceFromOracle(error_message) from e

def get_factory(self) -> Contract:
"""
Factory contract creates the pools for token pairs
Expand Down Expand Up @@ -169,8 +153,8 @@ def get_price(
raise CannotGetPriceFromOracle(error_message) from e

# Decimals needs to be adjusted
token_decimals = self.get_decimals(token_address)
token_2_decimals = self.get_decimals(token_address_2)
token_decimals = get_decimals(token_address, self.ethereum_client)
token_2_decimals = get_decimals(token_address_2, self.ethereum_client)

# https://docs.uniswap.org/sdk/guides/fetching-prices
if not reversed:
Expand Down
56 changes: 40 additions & 16 deletions gnosis/eth/tests/test_oracles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from eth_account import Account
from eth_typing import ChecksumAddress

from .. import EthereumClient
from ..oracles import (
Expand All @@ -23,6 +24,7 @@
YearnOracle,
ZerionComposedOracle,
)
from ..oracles.oracles import get_decimals as oracles_get_decimals
from .ethereum_test_case import EthereumTestCaseMixin
from .utils import just_test_if_mainnet_node

Expand Down Expand Up @@ -129,32 +131,36 @@ def test_calculate_pair_address(self, factory_address_mock: MagicMock):
)

def test_get_price(self):
oracles_get_decimals.cache_clear()
mainnet_node = just_test_if_mainnet_node()
ethereum_client = EthereumClient(mainnet_node)
uniswap_v2_oracle = UniswapV2Oracle(ethereum_client)

self.assertEqual(oracles_get_decimals.cache_info().currsize, 0)

price = uniswap_v2_oracle.get_price(
gno_token_mainnet_address, weth_token_mainnet_address
)
self.assertLess(price, 1)
self.assertGreater(price, 0)

self.assertEqual(oracles_get_decimals.cache_info().currsize, 2)

# Test with 2 stablecoins
price = uniswap_v2_oracle.get_price(
dai_token_mainnet_address, usdt_token_mainnet_address
)
self.assertAlmostEqual(price, 1.0, delta=0.5)
self.assertEqual(
uniswap_v2_oracle._decimals_cache[dai_token_mainnet_address], 18
)
self.assertEqual(
uniswap_v2_oracle._decimals_cache[usdt_token_mainnet_address], 6
)
self.assertEqual(oracles_get_decimals.cache_info().currsize, 4)
self.assertEqual(oracles_get_decimals.cache_info().hits, 0)

price = uniswap_v2_oracle.get_price(
usdt_token_mainnet_address, dai_token_mainnet_address
)
self.assertAlmostEqual(price, 1.0, delta=0.5)
self.assertEqual(oracles_get_decimals.cache_info().currsize, 4)
self.assertEqual(oracles_get_decimals.cache_info().hits, 2)
oracles_get_decimals.cache_clear()

def test_get_price_contract_not_deployed(self):
uniswap_v2_oracle = UniswapV2Oracle(self.ethereum_client)
Expand All @@ -165,40 +171,49 @@ def test_get_price_contract_not_deployed(self):
):
uniswap_v2_oracle.get_price(random_token_address)

@mock.patch("gnosis.eth.oracles.oracles.get_decimals", autospec=True)
@mock.patch.object(
UniswapV2Oracle,
"factory_address",
return_value="0x5C69bEe701ef814a2B6a3EDD4B1652CB9cc5aA6f",
new_callable=mock.PropertyMock,
)
@mock.patch.object(
UniswapV2Oracle, "get_decimals", return_value=(18, 3), autospec=True
)
@mock.patch.object(
UniswapV2Oracle, "get_reserves", return_value=(int(1e20), 600), autospec=True
)
def test_get_price_liquidity(
self,
get_reserves_mock: MagicMock,
get_decimals_mock: MagicMock,
factory_address_mock: MagicMock,
get_decimals_mock: MagicMock,
):
uniswap_v2_oracle = UniswapV2Oracle(self.ethereum_client)
token_1, token_2 = (
"0xA14F6F8867DB84a45BCD148bfaf4e4f54B4B9b12",
"0xC426A8F4C79EF274Ed93faC9e1A09bFC5659B06B",
)

def get_decimals_mock_fn(
token_address: ChecksumAddress, ethereum_client: EthereumClient
) -> int:
if token_address == token_1:
return 18
else:
return 3

get_decimals_mock.side_effect = get_decimals_mock_fn
uniswap_v2_oracle = UniswapV2Oracle(self.ethereum_client)

with self.assertRaisesMessage(CannotGetPriceFromOracle, "Not enough liquidity"):
uniswap_v2_oracle.get_price(token_1, token_2)

get_reserves_mock.return_value = (int(1e20), 6000)
self.assertEqual(uniswap_v2_oracle.get_price(token_1, token_2), 0.06)

get_reserves_mock.return_value = reversed(get_reserves_mock.return_value)
self.assertEqual(
uniswap_v2_oracle.get_price(token_2, token_1), 0.06
) # Reserves were inverted
with self.assertRaisesMessage(CannotGetPriceFromOracle, "Not enough liquidity"):
self.assertEqual(
uniswap_v2_oracle.get_price(token_2, token_1), 0.06
) # Reserves were inverted

def test_get_pool_token_price(self):
dai_eth_pool_address = "0xA478c2975Ab1Ea89e8196811F51A7B7Ade33eB11"
Expand All @@ -212,6 +227,7 @@ def test_get_pool_token_price(self):

class TestSushiSwapOracle(EthereumTestCaseMixin, TestCase):
def test_get_price(self):
oracles_get_decimals.cache_clear()
mainnet_node = just_test_if_mainnet_node()
ethereum_client = EthereumClient(mainnet_node)
sushiswap_oracle = SushiswapOracle(ethereum_client)
Expand All @@ -220,23 +236,31 @@ def test_get_price(self):
wbtc_token_mainnet_address, weth_token_mainnet_address
)
self.assertGreater(price, 0)
self.assertEqual(oracles_get_decimals.cache_info().currsize, 2)

# Test with 2 stablecoins
price = sushiswap_oracle.get_price(
dai_token_mainnet_address, usdt_token_mainnet_address
)
self.assertAlmostEqual(price, 1.0, delta=0.5)
self.assertEqual(oracles_get_decimals.cache_info().currsize, 4)

self.assertEqual(oracles_get_decimals.cache_info().hits, 0)
self.assertEqual(
sushiswap_oracle._decimals_cache[dai_token_mainnet_address], 18
oracles_get_decimals(dai_token_mainnet_address, ethereum_client), 18
)
self.assertEqual(
sushiswap_oracle._decimals_cache[usdt_token_mainnet_address], 6
oracles_get_decimals(usdt_token_mainnet_address, ethereum_client), 6
)
self.assertEqual(oracles_get_decimals.cache_info().hits, 2)

price = sushiswap_oracle.get_price(
usdt_token_mainnet_address, dai_token_mainnet_address
)
self.assertAlmostEqual(price, 1.0, delta=0.5)
self.assertEqual(oracles_get_decimals.cache_info().currsize, 4)
self.assertEqual(oracles_get_decimals.cache_info().hits, 4)
oracles_get_decimals.cache_clear()


class TestAaveOracle(EthereumTestCaseMixin, TestCase):
Expand Down

0 comments on commit ed83837

Please sign in to comment.