Skip to content
This repository has been archived by the owner on Jul 1, 2021. It is now read-only.

Commit

Permalink
Refactor code to be more DRY
Browse files Browse the repository at this point in the history
  • Loading branch information
cburgdorf committed Feb 14, 2020
1 parent f5c9e59 commit dcf264b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 43 deletions.
43 changes: 19 additions & 24 deletions trinity/protocol/common/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from eth_typing import BlockNumber, Hash32

from p2p.abc import ConnectionAPI
from p2p.logic import Application
from p2p.qualifiers import HasProtocol

Expand All @@ -17,6 +18,21 @@
LESProtocolV2) | HasProtocol(LESProtocolV1)


def choose_eth_or_les_api(
connection: ConnectionAPI) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]:

if connection.has_protocol(ETHProtocol):
return connection.get_logic(ETHAPI.name, ETHAPI)
elif connection.has_protocol(ETHProtocolV63):
return connection.get_logic(ETHV63API.name, ETHV63API)
elif connection.has_protocol(LESProtocolV2):
return connection.get_logic(LESV2API.name, LESV2API)
elif connection.has_protocol(LESProtocolV1):
return connection.get_logic(LESV1API.name, LESV1API)
else:
raise Exception("Unreachable code path")


class ChainInfo(Application, ChainInfoAPI):
name = 'eth1-chain-info'

Expand All @@ -31,16 +47,7 @@ def genesis_hash(self) -> Hash32:
return self._get_logic().genesis_hash

def _get_logic(self) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]:
if self.connection.has_protocol(ETHProtocol):
return self.connection.get_logic(ETHAPI.name, ETHAPI)
elif self.connection.has_protocol(ETHProtocolV63):
return self.connection.get_logic(ETHV63API.name, ETHV63API)
elif self.connection.has_protocol(LESProtocolV2):
return self.connection.get_logic(LESV2API.name, LESV2API)
elif self.connection.has_protocol(LESProtocolV1):
return self.connection.get_logic(LESV1API.name, LESV1API)
else:
raise Exception("Unreachable code path")
return choose_eth_or_les_api(self.connection)


class HeadInfo(Application, HeadInfoAPI):
Expand All @@ -50,20 +57,8 @@ class HeadInfo(Application, HeadInfoAPI):

@cached_property
def _tracker(self) -> HeadInfoAPI:
if self.connection.has_protocol(ETHProtocol):
eth_logic = self.connection.get_logic(ETHAPI.name, ETHAPI)
return eth_logic.head_info
elif self.connection.has_protocol(ETHProtocolV63):
eth_v63_logic = self.connection.get_logic(ETHV63API.name, ETHV63API)
return eth_v63_logic.head_info
elif self.connection.has_protocol(LESProtocolV2):
les_v2_logic = self.connection.get_logic(LESV2API.name, LESV2API)
return les_v2_logic.head_info
elif self.connection.has_protocol(LESProtocolV1):
les_v1_logic = self.connection.get_logic(LESV1API.name, LESV1API)
return les_v1_logic.head_info
else:
raise Exception("Unreachable code path")
api = choose_eth_or_les_api(self.connection)
return api.head_info

@property
def head_td(self) -> int:
Expand Down
21 changes: 2 additions & 19 deletions trinity/protocol/common/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,14 @@
from trinity.constants import TO_NETWORKING_BROADCAST_CONFIG
from trinity.exceptions import BaseForkIDValidationError
from trinity.protocol.common.abc import ChainInfoAPI, HeadInfoAPI
from trinity.protocol.common.api import ChainInfo, HeadInfo
from trinity.protocol.common.api import ChainInfo, HeadInfo, choose_eth_or_les_api
from trinity.protocol.eth.api import ETHV63API, ETHAPI
from trinity.protocol.eth.forkid import (
extract_fork_blocks,
extract_forkid,
validate_forkid,
)
from trinity.protocol.eth.proto import ETHProtocol, ETHProtocolV63
from trinity.protocol.les.api import LESV1API, LESV2API
from trinity.protocol.les.proto import LESProtocolV1, LESProtocolV2

from trinity.components.builtin.network_db.connection.tracker import ConnectionTrackerClient
from trinity.components.builtin.network_db.eth1_peer_db.tracker import (
Expand All @@ -96,22 +94,7 @@ class BaseChainPeer(BasePeer):

@cached_property
def chain_api(self) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]:
if self.connection.has_logic(ETHAPI.name):
if self.connection.has_protocol(ETHProtocol):
return self.connection.get_logic(ETHAPI.name, ETHAPI)
elif self.connection.has_protocol(ETHProtocolV63):
return self.connection.get_logic(ETHV63API.name, ETHV63API)
else:
raise Exception("Should be unreachable")
elif self.connection.has_logic(LESV1API.name):
if self.connection.has_protocol(LESProtocolV2):
return self.connection.get_logic(LESV2API.name, LESV2API)
elif self.connection.has_protocol(LESProtocolV1):
return self.connection.get_logic(LESV1API.name, LESV1API)
else:
raise Exception("Should be unreachable")
else:
raise Exception("Should be unreachable")
return choose_eth_or_les_api(self.connection)

@cached_property
def head_info(self) -> HeadInfoAPI:
Expand Down

0 comments on commit dcf264b

Please sign in to comment.