diff --git a/src/dali/abstract_ccxt_pair_converter_plugin.py b/src/dali/abstract_ccxt_pair_converter_plugin.py index 2dd4ab0c..ba7690e1 100644 --- a/src/dali/abstract_ccxt_pair_converter_plugin.py +++ b/src/dali/abstract_ccxt_pair_converter_plugin.py @@ -247,6 +247,7 @@ DAYS_IN_WEEK: int = 7 MANY_YEARS_IN_THE_FUTURE: relativedelta = relativedelta(years=100) + class AssetPairAndHistoricalPrice(NamedTuple): from_asset: str to_asset: str @@ -513,16 +514,23 @@ def find_historical_bar(self, from_asset: str, to_asset: str, timestamp: datetim def find_historical_bars( self, from_asset: str, to_asset: str, timestamp: datetime, exchange: str, all_bars: bool = False, timespan: str = _MINUTE ) -> Optional[List[HistoricalBar]]: + + # Guard clause to pull from cache + if all_bars: + cached_bundle = self._get_bundle_from_cache(AssetPairAndTimestamp(timestamp, from_asset, to_asset, exchange)) + if cached_bundle: + # If the last bar in the bundle is within the last week, return the bundle + if (datetime.now(timezone.utc) - cached_bundle[-1].timestamp).total_seconds() <= _TIME_GRANULARITY_STRING_TO_SECONDS[_ONE_WEEK]: + return cached_bundle + # If the last bar in the bundle is older than a week, we need to start from the next millisecond + # We will pull the rest later and add to this bundle + timestamp = cached_bundle[-1].timestamp + timedelta(milliseconds=1) + + # Stage 1 Initialization result: List[HistoricalBar] = [] - retry_count: int = 0 self.__transaction_count += 1 - if timespan in _TIME_GRANULARITY_SET: - if exchange in _NONSTANDARD_GRANULARITY_EXCHANGE_SET: - retry_count = _TIME_GRANULARITY_DICT[exchange].index(timespan) - else: - retry_count = _TIME_GRANULARITY.index(timespan) - else: - raise RP2ValueError("Internal error: Invalid time span passed to find_historical_bars.") + retry_count: int = self._initialize_retry_count(exchange, timespan) + current_exchange: Any = self.__exchanges[exchange] ms_timestamp: int = int(timestamp.timestamp() * _MS_IN_SECOND) csv_pricing: Any = self.__csv_pricing_dict.get(exchange) @@ -566,18 +574,6 @@ def find_historical_bars( within_last_week: bool = False - # Get bundles of bars if they exist, saving us from making a call to the API - if all_bars: - cached_bundle: Optional[List[HistoricalBar]] = self._get_bundle_from_cache(AssetPairAndTimestamp(timestamp, from_asset, to_asset, exchange)) - if cached_bundle: - result.extend(cached_bundle) - timestamp = cached_bundle[-1].timestamp + timedelta(milliseconds=1) - ms_timestamp = int(timestamp.timestamp() * _MS_IN_SECOND) - - # If the bundle of bars is within the last week, we don't need to pull new optimization data. - if result and (datetime.now(timezone.utc) - result[-1].timestamp).total_seconds() > _TIME_GRANULARITY_STRING_TO_SECONDS[_ONE_WEEK]: - within_last_week = True - while (retry_count < len(_TIME_GRANULARITY_DICT.get(exchange, _TIME_GRANULARITY))) and not within_last_week: timeframe: str = _TIME_GRANULARITY_DICT.get(exchange, _TIME_GRANULARITY)[retry_count] request_count: int = 0 @@ -718,6 +714,8 @@ def find_historical_bars( ) ) elif all_bars: + if cached_bundle: + result = cached_bundle + result self._add_bundle_to_cache(AssetPairAndTimestamp(timestamp, from_asset, to_asset, exchange), result) break # If historical_data is empty we have hit the end of records and need to return else: @@ -728,6 +726,17 @@ def find_historical_bars( return result + def _initialize_retry_count(self, exchange: str, timespan: str) -> int: + if timespan not in _TIME_GRANULARITY_SET: + raise RP2ValueError(f"Internal error: Invalid time span '{timespan}' passed to find_historical_bars.") + + granularity = _TIME_GRANULARITY_DICT[exchange] if exchange in _NONSTANDARD_GRANULARITY_EXCHANGE_SET else _TIME_GRANULARITY + + if timespan not in granularity: + raise RP2ValueError(f"Internal error: Time span '{timespan}' is not valid for exchange '{exchange}'.") + + return granularity.index(timespan) + def _add_alternative_markets(self, graph: MappedGraph[str], current_markets: Dict[str, List[str]]) -> None: for base_asset, quote_asset in _ALT_MARKET_BY_BASE_DICT.items(): alt_market = base_asset + quote_asset diff --git a/tests/test_abstract_ccxt_pair_converter.py b/tests/test_abstract_ccxt_pair_converter.py index 23d102fe..5c703189 100644 --- a/tests/test_abstract_ccxt_pair_converter.py +++ b/tests/test_abstract_ccxt_pair_converter.py @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional, Set import pytest from prezzemolo.vertex import Vertex from rp2.rp2_decimal import RP2Decimal +from rp2.rp2_error import RP2ValueError from dali.abstract_ccxt_pair_converter_plugin import ( + _BINANCE, + _COINBASE_PRO, + _ONE_HOUR, + _SIX_HOUR, + _TIME_GRANULARITY, + _TIME_GRANULARITY_DICT, MARKET_PADDING_IN_WEEKS, AbstractCcxtPairConverterPlugin, ) @@ -74,10 +81,11 @@ def unoptimized_graph(self, vertex_list: Dict[str, Vertex[str]]) -> MappedGraph[ @pytest.fixture def historical_bars(self) -> Dict[str, HistoricalBar]: + now_time = datetime.now(timezone.utc) return { MARKET_START: HistoricalBar( duration=timedelta(weeks=1), - timestamp=datetime(2023, 1, 1), + timestamp=now_time, open=RP2Decimal("1.0"), high=RP2Decimal("2.0"), low=RP2Decimal("0.5"), @@ -86,7 +94,7 @@ def historical_bars(self) -> Dict[str, HistoricalBar]: ), ONE_WEEK_EARLIER: HistoricalBar( duration=timedelta(weeks=1), - timestamp=datetime(2023, 1, 1) - timedelta(weeks=1), + timestamp=now_time - timedelta(weeks=1), open=RP2Decimal("1.1"), high=RP2Decimal("2.1"), low=RP2Decimal("0.6"), @@ -116,7 +124,7 @@ def test_retrieve_historical_bars( plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value) unoptimized_assets: Set[str] = {"A", "B"} optimization_candidates: Set[Vertex[str]] = {vertex_list["A"], vertex_list["B"], vertex_list["C"]} - week_start_date = datetime(2023, 1, 1) + week_start_date = historical_bars[MARKET_START].timestamp mocker.patch.object(plugin, "_AbstractCcxtPairConverterPlugin__exchange_markets", {TEST_EXCHANGE: TEST_MARKETS}) @@ -148,7 +156,7 @@ def find_historical_bars_side_effect( def test_generate_optimizations(self, historical_bars: Dict[str, HistoricalBar]) -> None: plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value) - week_start_date = datetime(2023, 1, 1) + week_start_date = historical_bars[MARKET_START].timestamp child_bars = {"A": {"B": [historical_bars[MARKET_START], historical_bars[ONE_WEEK_EARLIER]]}} @@ -193,3 +201,40 @@ def test_refine_and_finalize_optimizations(self) -> None: assert refined_optimizations[datetime(2023, 1, 4)]["A"]["C"] == 1.0 assert refined_optimizations[datetime(2023, 1, 4)]["D"]["F"] == 1.0 assert "E" not in refined_optimizations[datetime(2023, 1, 4)]["D"] + + def test_initialize_retry_count(self) -> None: + plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value) + + assert plugin._initialize_retry_count(_BINANCE, _ONE_HOUR) == _TIME_GRANULARITY.index(_ONE_HOUR) # pylint: disable=protected-access + assert plugin._initialize_retry_count(_COINBASE_PRO, _SIX_HOUR) == _TIME_GRANULARITY_DICT[_COINBASE_PRO].index( # pylint: disable=protected-access + _SIX_HOUR + ) + with pytest.raises(RP2ValueError): + # Binance does not support 6 hour granularity + assert plugin._initialize_retry_count(_BINANCE, _SIX_HOUR) # pylint: disable=protected-access + assert plugin._initialize_retry_count(_COINBASE_PRO, "invalid") # pylint: disable=protected-access + + def test_find_historical_bars_guard_clause(self, mocker: Any, historical_bars: Dict[str, HistoricalBar]) -> None: + plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value) + + mocker.patch.object(plugin, "_get_bundle_from_cache", return_value=[historical_bars[MARKET_START]]) + + bars = plugin.find_historical_bars("A", "B", datetime(2023, 1, 1), TEST_EXCHANGE, True) + + assert bars + assert len(bars) == 1 + assert bars[0] == historical_bars[MARKET_START] + + # To be enabled when _fetch_historical_bars is implemented + def disabled_test_find_historical_bars_add_to_cache(self, mocker: Any, historical_bars: Dict[str, HistoricalBar]) -> None: + plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value) + + mocker.patch.object(plugin, "_get_bundle_from_cache", return_value=historical_bars[ONE_WEEK_EARLIER]) + mocker.patch.object(plugin, "_fetch_historical_bars", return_value=[historical_bars[MARKET_START]]) # function that calls the API + + bars = plugin.find_historical_bars("A", "B", datetime(2023, 1, 1), TEST_EXCHANGE, True) + + assert bars + assert len(bars) == 2 + assert bars[0] == historical_bars[ONE_WEEK_EARLIER] + assert bars[1] == historical_bars[MARKET_START]