Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor find historical bars #270

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions src/dali/abstract_ccxt_pair_converter_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@
DAYS_IN_WEEK: int = 7
MANY_YEARS_IN_THE_FUTURE: relativedelta = relativedelta(years=100)


class AssetPairAndHistoricalPrice(NamedTuple):
from_asset: str
to_asset: str
Expand Down Expand Up @@ -513,16 +514,23 @@ def find_historical_bar(self, from_asset: str, to_asset: str, timestamp: datetim
def find_historical_bars(
self, from_asset: str, to_asset: str, timestamp: datetime, exchange: str, all_bars: bool = False, timespan: str = _MINUTE
) -> Optional[List[HistoricalBar]]:

# Guard clause to pull from cache
if all_bars:
cached_bundle = self._get_bundle_from_cache(AssetPairAndTimestamp(timestamp, from_asset, to_asset, exchange))
if cached_bundle:
# If the last bar in the bundle is within the last week, return the bundle
if (datetime.now(timezone.utc) - cached_bundle[-1].timestamp).total_seconds() <= _TIME_GRANULARITY_STRING_TO_SECONDS[_ONE_WEEK]:
return cached_bundle
# If the last bar in the bundle is older than a week, we need to start from the next millisecond
# We will pull the rest later and add to this bundle
timestamp = cached_bundle[-1].timestamp + timedelta(milliseconds=1)

# Stage 1 Initialization
result: List[HistoricalBar] = []
retry_count: int = 0
self.__transaction_count += 1
if timespan in _TIME_GRANULARITY_SET:
if exchange in _NONSTANDARD_GRANULARITY_EXCHANGE_SET:
retry_count = _TIME_GRANULARITY_DICT[exchange].index(timespan)
else:
retry_count = _TIME_GRANULARITY.index(timespan)
else:
raise RP2ValueError("Internal error: Invalid time span passed to find_historical_bars.")
retry_count: int = self._initialize_retry_count(exchange, timespan)

current_exchange: Any = self.__exchanges[exchange]
ms_timestamp: int = int(timestamp.timestamp() * _MS_IN_SECOND)
csv_pricing: Any = self.__csv_pricing_dict.get(exchange)
Expand Down Expand Up @@ -566,18 +574,6 @@ def find_historical_bars(

within_last_week: bool = False

# Get bundles of bars if they exist, saving us from making a call to the API
if all_bars:
cached_bundle: Optional[List[HistoricalBar]] = self._get_bundle_from_cache(AssetPairAndTimestamp(timestamp, from_asset, to_asset, exchange))
if cached_bundle:
result.extend(cached_bundle)
timestamp = cached_bundle[-1].timestamp + timedelta(milliseconds=1)
ms_timestamp = int(timestamp.timestamp() * _MS_IN_SECOND)

# If the bundle of bars is within the last week, we don't need to pull new optimization data.
if result and (datetime.now(timezone.utc) - result[-1].timestamp).total_seconds() > _TIME_GRANULARITY_STRING_TO_SECONDS[_ONE_WEEK]:
within_last_week = True

while (retry_count < len(_TIME_GRANULARITY_DICT.get(exchange, _TIME_GRANULARITY))) and not within_last_week:
timeframe: str = _TIME_GRANULARITY_DICT.get(exchange, _TIME_GRANULARITY)[retry_count]
request_count: int = 0
Expand Down Expand Up @@ -718,6 +714,8 @@ def find_historical_bars(
)
)
elif all_bars:
if cached_bundle:
result = cached_bundle + result
self._add_bundle_to_cache(AssetPairAndTimestamp(timestamp, from_asset, to_asset, exchange), result)
break # If historical_data is empty we have hit the end of records and need to return
else:
Expand All @@ -728,6 +726,17 @@ def find_historical_bars(

return result

def _initialize_retry_count(self, exchange: str, timespan: str) -> int:
    """Return the starting index into the exchange's granularity list for *timespan*.

    Raises:
        RP2ValueError: if *timespan* is not a known granularity, or is not
            supported by *exchange*'s granularity list.
    """
    if timespan not in _TIME_GRANULARITY_SET:
        raise RP2ValueError(f"Internal error: Invalid time span '{timespan}' passed to find_historical_bars.")

    # Exchanges with nonstandard granularities carry their own ordered list;
    # everything else uses the shared default ordering.
    if exchange in _NONSTANDARD_GRANULARITY_EXCHANGE_SET:
        granularities = _TIME_GRANULARITY_DICT[exchange]
    else:
        granularities = _TIME_GRANULARITY

    if timespan not in granularities:
        raise RP2ValueError(f"Internal error: Time span '{timespan}' is not valid for exchange '{exchange}'.")

    return granularities.index(timespan)

def _add_alternative_markets(self, graph: MappedGraph[str], current_markets: Dict[str, List[str]]) -> None:
for base_asset, quote_asset in _ALT_MARKET_BY_BASE_DICT.items():
alt_market = base_asset + quote_asset
Expand Down
55 changes: 50 additions & 5 deletions tests/test_abstract_ccxt_pair_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Set

import pytest
from prezzemolo.vertex import Vertex
from rp2.rp2_decimal import RP2Decimal
from rp2.rp2_error import RP2ValueError

from dali.abstract_ccxt_pair_converter_plugin import (
_BINANCE,
_COINBASE_PRO,
_ONE_HOUR,
_SIX_HOUR,
_TIME_GRANULARITY,
_TIME_GRANULARITY_DICT,
MARKET_PADDING_IN_WEEKS,
AbstractCcxtPairConverterPlugin,
)
Expand Down Expand Up @@ -74,10 +81,11 @@ def unoptimized_graph(self, vertex_list: Dict[str, Vertex[str]]) -> MappedGraph[

@pytest.fixture
def historical_bars(self) -> Dict[str, HistoricalBar]:
now_time = datetime.now(timezone.utc)
return {
MARKET_START: HistoricalBar(
duration=timedelta(weeks=1),
timestamp=datetime(2023, 1, 1),
timestamp=now_time,
open=RP2Decimal("1.0"),
high=RP2Decimal("2.0"),
low=RP2Decimal("0.5"),
Expand All @@ -86,7 +94,7 @@ def historical_bars(self) -> Dict[str, HistoricalBar]:
),
ONE_WEEK_EARLIER: HistoricalBar(
duration=timedelta(weeks=1),
timestamp=datetime(2023, 1, 1) - timedelta(weeks=1),
timestamp=now_time - timedelta(weeks=1),
open=RP2Decimal("1.1"),
high=RP2Decimal("2.1"),
low=RP2Decimal("0.6"),
Expand Down Expand Up @@ -116,7 +124,7 @@ def test_retrieve_historical_bars(
plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value)
unoptimized_assets: Set[str] = {"A", "B"}
optimization_candidates: Set[Vertex[str]] = {vertex_list["A"], vertex_list["B"], vertex_list["C"]}
week_start_date = datetime(2023, 1, 1)
week_start_date = historical_bars[MARKET_START].timestamp

mocker.patch.object(plugin, "_AbstractCcxtPairConverterPlugin__exchange_markets", {TEST_EXCHANGE: TEST_MARKETS})

Expand Down Expand Up @@ -148,7 +156,7 @@ def find_historical_bars_side_effect(

def test_generate_optimizations(self, historical_bars: Dict[str, HistoricalBar]) -> None:
plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value)
week_start_date = datetime(2023, 1, 1)
week_start_date = historical_bars[MARKET_START].timestamp

child_bars = {"A": {"B": [historical_bars[MARKET_START], historical_bars[ONE_WEEK_EARLIER]]}}

Expand Down Expand Up @@ -193,3 +201,40 @@ def test_refine_and_finalize_optimizations(self) -> None:
assert refined_optimizations[datetime(2023, 1, 4)]["A"]["C"] == 1.0
assert refined_optimizations[datetime(2023, 1, 4)]["D"]["F"] == 1.0
assert "E" not in refined_optimizations[datetime(2023, 1, 4)]["D"]

def test_initialize_retry_count(self) -> None:
    """Check retry-count initialization for standard and nonstandard exchanges, plus error paths."""
    plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value)

    # Binance uses the standard granularity list; Coinbase Pro has its own list.
    assert plugin._initialize_retry_count(_BINANCE, _ONE_HOUR) == _TIME_GRANULARITY.index(_ONE_HOUR)  # pylint: disable=protected-access
    assert plugin._initialize_retry_count(_COINBASE_PRO, _SIX_HOUR) == _TIME_GRANULARITY_DICT[_COINBASE_PRO].index(  # pylint: disable=protected-access
        _SIX_HOUR
    )

    # Each expected failure needs its own raises block: with both calls under a
    # single `with pytest.raises(...)`, the second call was unreachable because
    # the first one already raised. Also drop the `assert` — a call that is
    # expected to raise never produces a value to assert on.
    with pytest.raises(RP2ValueError):
        # Binance does not support 6 hour granularity
        plugin._initialize_retry_count(_BINANCE, _SIX_HOUR)  # pylint: disable=protected-access
    with pytest.raises(RP2ValueError):
        plugin._initialize_retry_count(_COINBASE_PRO, "invalid")  # pylint: disable=protected-access

def test_find_historical_bars_guard_clause(self, mocker: Any, historical_bars: Dict[str, HistoricalBar]) -> None:
    """A cached bundle whose last bar is within the past week is returned directly from cache."""
    plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value)
    cached_bundle = [historical_bars[MARKET_START]]

    mocker.patch.object(plugin, "_get_bundle_from_cache", return_value=cached_bundle)

    result = plugin.find_historical_bars("A", "B", datetime(2023, 1, 1), TEST_EXCHANGE, True)

    assert result is not None
    assert len(result) == 1
    assert result[0] == historical_bars[MARKET_START]

# To be enabled when _fetch_historical_bars is implemented
def disabled_test_find_historical_bars_add_to_cache(self, mocker: Any, historical_bars: Dict[str, HistoricalBar]) -> None:
    """Cached bars older than a week are merged with freshly fetched bars and returned in order."""
    plugin = MockAbstractCcxtPairConverterPlugin(Keyword.HISTORICAL_PRICE_HIGH.value)

    # _get_bundle_from_cache must return a List[HistoricalBar]: the guard clause
    # indexes cached_bundle[-1], so the original bare-HistoricalBar return value
    # would have exercised the wrong code path once this test is enabled.
    mocker.patch.object(plugin, "_get_bundle_from_cache", return_value=[historical_bars[ONE_WEEK_EARLIER]])
    mocker.patch.object(plugin, "_fetch_historical_bars", return_value=[historical_bars[MARKET_START]])  # function that calls the API

    bars = plugin.find_historical_bars("A", "B", datetime(2023, 1, 1), TEST_EXCHANGE, True)

    assert bars
    assert len(bars) == 2
    assert bars[0] == historical_bars[ONE_WEEK_EARLIER]
    assert bars[1] == historical_bars[MARKET_START]
Loading