Skip to content

Commit

Permalink
Revert SubscriptionManager
Browse files Browse the repository at this point in the history
  • Loading branch information
droserasprout committed Aug 5, 2021
1 parent 2468746 commit 9851271
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 95 deletions.
60 changes: 1 addition & 59 deletions src/dipdup/datasources/datasource.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,14 @@
from abc import abstractmethod
from collections import defaultdict
from copy import copy
from enum import Enum
from functools import partial
import logging
from typing import Awaitable, DefaultDict, List, Optional, Protocol, Set
from typing import Awaitable, List, Optional, Protocol

from pydantic.dataclasses import dataclass
from pydantic.fields import Field
from pyee import AsyncIOEventEmitter # type: ignore

from dipdup.config import HTTPConfig
from dipdup.http import HTTPGateway
from dipdup.models import BigMapData, HeadBlockData, OperationData


# NOTE: Since there's no other index datasource
_logger = logging.getLogger('dipdup.tzkt')


class EventType(Enum):
operations = 'operatitions'
big_maps = 'big_maps'
Expand Down Expand Up @@ -88,51 +78,3 @@ def emit_rollback(self, from_level: int, to_level: int) -> None:

def emit_head(self, block: HeadBlockData) -> None:
super().emit(EventType.head, datasource=self, block=block)


@dataclass
class Subscriptions:
address_transactions: Set[str] = Field(default_factory=set)
originations: bool = False
head: bool = False
big_maps: DefaultDict[str, Set[str]] = Field(default_factory=partial(defaultdict, set))

def get_pending(self, active_subscriptions: 'Subscriptions') -> 'Subscriptions':
return Subscriptions(
address_transactions=self.address_transactions.difference(active_subscriptions.address_transactions),
originations=not active_subscriptions.originations,
head=not active_subscriptions.head,
big_maps=defaultdict(set, {k: self.big_maps[k] for k in set(self.big_maps) - set(active_subscriptions.big_maps)}),
)

class SubscriptionManager:
def __init__(self) -> None:
self._subscriptions: Subscriptions = Subscriptions()
self._active_subscriptions: Subscriptions = Subscriptions()

def status(self, pending: bool = False) -> str:
subs = self.get_pending() if pending else self._active_subscriptions
big_maps_len = sum([len(v) for v in subs.big_maps.values()])
return f'{len(subs.address_transactions)} contracts, {int(subs.originations)} originations, {int(subs.head)} head, {big_maps_len} big maps'

def add_address_transaction_subscription(self, address: str) -> None:
self._subscriptions.address_transactions.add(address)

def add_origination_subscription(self) -> None:
self._subscriptions.originations = True

def add_head_subscription(self) -> None:
self._subscriptions.head = True

def add_big_map_subscription(self, address: str, paths: Set[str]) -> None:
self._subscriptions.big_maps[address] = self._subscriptions.big_maps[address] | paths

def get_pending(self) -> Subscriptions:
pending_subscriptions = self._subscriptions.get_pending(self._active_subscriptions)
return pending_subscriptions

def commit(self) -> None:
self._active_subscriptions = copy(self._subscriptions)

def reset(self) -> None:
self._active_subscriptions = Subscriptions()
51 changes: 19 additions & 32 deletions src/dipdup/datasources/tzkt/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
OperationHandlerOriginationPatternConfig,
OperationIndexConfig,
)
from dipdup.datasources.datasource import IndexDatasource, SubscriptionManager
from dipdup.datasources.datasource import IndexDatasource
from dipdup.datasources.tzkt.enums import TzktMessageType
from dipdup.models import BigMapAction, BigMapData, BlockData, HeadBlockData, OperationData
from dipdup.utils import groupby, split_by_chunks
Expand Down Expand Up @@ -292,13 +292,13 @@ def __init__(
self,
url: str,
http_config: Optional[HTTPConfig] = None,
realtime: bool = True,
) -> None:
super().__init__(url, http_config)
self._logger = logging.getLogger('dipdup.tzkt')
self._subscriptions: SubscriptionManager = SubscriptionManager()

self._realtime: bool = realtime
self._transaction_subscriptions: Set[str] = set()
self._origination_subscriptions: bool = False
self._big_map_subscriptions: Dict[str, Set[str]] = {}
self._client: Optional[BaseHubConnection] = None

self._block: Optional[HeadBlockData] = None
Expand Down Expand Up @@ -502,22 +502,24 @@ async def add_index(self, index_config: IndexConfigTemplateT) -> None:

if isinstance(index_config, OperationIndexConfig):
for contract_config in index_config.contracts or []:
self._subscriptions.add_address_transaction_subscription(cast(ContractConfig, contract_config).address)

self._transaction_subscriptions.add(cast(ContractConfig, contract_config).address)
for handler_config in index_config.handlers:
for pattern_config in handler_config.pattern:
if isinstance(pattern_config, OperationHandlerOriginationPatternConfig):
self._subscriptions.add_origination_subscription()
self._origination_subscriptions = True

elif isinstance(index_config, BigMapIndexConfig):
for big_map_handler_config in index_config.handlers:
address, path = big_map_handler_config.contract_config.address, big_map_handler_config.path
self._subscriptions.add_big_map_subscription(address, set(path))
if address not in self._big_map_subscriptions:
self._big_map_subscriptions[address] = set()
if path not in self._big_map_subscriptions[address]:
self._big_map_subscriptions[address].add(path)

else:
raise NotImplementedError(f'Index kind `{index_config.kind}` is not supported')

await self.subscribe()
await self._on_connect()

def _get_client(self) -> BaseHubConnection:
"""Create SignalR client, register message callbacks"""
Expand Down Expand Up @@ -554,34 +556,19 @@ async def run(self) -> None:

async def _on_connect(self) -> None:
"""Subscribe to all required channels on established WS connection"""
self._logger.info('Connected to server')
await self.subscribe()

async def _on_disconnect(self) -> None:
self._logger.info('Disconnected from server')
self._subscriptions.reset()

async def subscribe(self) -> None:
"""Subscribe to all required channels"""
if not self._realtime:
if self._get_client().transport.state != ConnectionState.connected:
return

pending_subscriptions = self._subscriptions.get_pending()
self._logger.info('Subscribing to channels')
self._logger.info('Active: %s', self._subscriptions.status(False))
self._logger.info('Pending: %s', self._subscriptions.status(True))

for address in pending_subscriptions.address_transactions:
await self._subscribe_to_address_transactions(address)
if pending_subscriptions.originations:
self._logger.info('Connected to server')
await self._subscribe_to_head()
for address in self._transaction_subscriptions:
await self._subscribe_to_transactions(address)
# NOTE: All originations are passed to matcher
if self._origination_subscriptions:
await self._subscribe_to_originations()
if pending_subscriptions.head:
await self._subscribe_to_head()
for address, paths in pending_subscriptions.big_maps.items():
for address, paths in self._big_map_subscriptions.items():
await self._subscribe_to_big_maps(address, paths)

self._subscriptions.commit()

# NOTE: Pay attention: this is not a pyee callback
def _on_error(self, message: CompletionMessage) -> NoReturn:
"""Raise exception from WS server's error message"""
Expand Down
3 changes: 1 addition & 2 deletions src/dipdup/dipdup.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def __init__(self, config: DipDupConfig) -> None:

async def init(self) -> None:
"""Create new or update existing dipdup project"""
await self._create_datasources(realtime=False)
await self._create_datasources()

async with AsyncExitStack() as stack:
for datasource in self._datasources.values():
Expand Down Expand Up @@ -271,7 +271,6 @@ async def _create_datasources(self, realtime: bool = True) -> None:
datasource = TzktDatasource(
url=datasource_config.url,
http_config=datasource_config.http,
realtime=realtime,
)
elif isinstance(datasource_config, BcdDatasourceConfig):
datasource = BcdDatasource(
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/test_rollback.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def test_rollback_ok(self):
config.database.path = ':memory:'

datasource_name, datasource_config = list(config.datasources.items())[0]
datasource = TzktDatasource('test', realtime=False)
datasource = TzktDatasource('test')
dipdup = DipDup(config)
dipdup._datasources[datasource_name] = datasource
dipdup._datasources_by_config[datasource_config] = datasource
Expand All @@ -121,7 +121,7 @@ async def test_rollback_fail(self):
config.database.path = ':memory:'

datasource_name, datasource_config = list(config.datasources.items())[0]
datasource = TzktDatasource('test', realtime=False)
datasource = TzktDatasource('test')
dipdup = DipDup(config)
dipdup._datasources[datasource_name] = datasource
dipdup._datasources_by_config[datasource_config] = datasource
Expand Down

0 comments on commit 9851271

Please sign in to comment.