diff --git a/trinity/protocol/common/api.py b/trinity/protocol/common/api.py index 8a97bcfcba..c215921f50 100644 --- a/trinity/protocol/common/api.py +++ b/trinity/protocol/common/api.py @@ -3,6 +3,7 @@ from eth_typing import BlockNumber, Hash32 +from p2p.abc import ConnectionAPI from p2p.logic import Application from p2p.qualifiers import HasProtocol @@ -17,6 +18,21 @@ LESProtocolV2) | HasProtocol(LESProtocolV1) +def choose_eth_or_les_api( + connection: ConnectionAPI) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]: + + if connection.has_protocol(ETHProtocol): + return connection.get_logic(ETHAPI.name, ETHAPI) + elif connection.has_protocol(ETHProtocolV63): + return connection.get_logic(ETHV63API.name, ETHV63API) + elif connection.has_protocol(LESProtocolV2): + return connection.get_logic(LESV2API.name, LESV2API) + elif connection.has_protocol(LESProtocolV1): + return connection.get_logic(LESV1API.name, LESV1API) + else: + raise Exception("Unreachable code path") + + class ChainInfo(Application, ChainInfoAPI): name = 'eth1-chain-info' @@ -31,16 +47,7 @@ def genesis_hash(self) -> Hash32: return self._get_logic().genesis_hash def _get_logic(self) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]: - if self.connection.has_protocol(ETHProtocol): - return self.connection.get_logic(ETHAPI.name, ETHAPI) - elif self.connection.has_protocol(ETHProtocolV63): - return self.connection.get_logic(ETHV63API.name, ETHV63API) - elif self.connection.has_protocol(LESProtocolV2): - return self.connection.get_logic(LESV2API.name, LESV2API) - elif self.connection.has_protocol(LESProtocolV1): - return self.connection.get_logic(LESV1API.name, LESV1API) - else: - raise Exception("Unreachable code path") + return choose_eth_or_les_api(self.connection) class HeadInfo(Application, HeadInfoAPI): @@ -50,20 +57,8 @@ class HeadInfo(Application, HeadInfoAPI): @cached_property def _tracker(self) -> HeadInfoAPI: - if self.connection.has_protocol(ETHProtocol): - eth_logic = self.connection.get_logic(ETHAPI.name, ETHAPI) - return eth_logic.head_info - elif self.connection.has_protocol(ETHProtocolV63): - eth_v63_logic = self.connection.get_logic(ETHV63API.name, ETHV63API) - return eth_v63_logic.head_info - elif self.connection.has_protocol(LESProtocolV2): - les_v2_logic = self.connection.get_logic(LESV2API.name, LESV2API) - return les_v2_logic.head_info - elif self.connection.has_protocol(LESProtocolV1): - les_v1_logic = self.connection.get_logic(LESV1API.name, LESV1API) - return les_v1_logic.head_info - else: - raise Exception("Unreachable code path") + api = choose_eth_or_les_api(self.connection) + return api.head_info @property def head_td(self) -> int: diff --git a/trinity/protocol/common/peer.py b/trinity/protocol/common/peer.py index 21936ebc9e..195613b451 100644 --- a/trinity/protocol/common/peer.py +++ b/trinity/protocol/common/peer.py @@ -62,16 +62,14 @@ from trinity.constants import TO_NETWORKING_BROADCAST_CONFIG from trinity.exceptions import BaseForkIDValidationError from trinity.protocol.common.abc import ChainInfoAPI, HeadInfoAPI -from trinity.protocol.common.api import ChainInfo, HeadInfo +from trinity.protocol.common.api import ChainInfo, HeadInfo, choose_eth_or_les_api from trinity.protocol.eth.api import ETHV63API, ETHAPI from trinity.protocol.eth.forkid import ( extract_fork_blocks, extract_forkid, validate_forkid, ) -from trinity.protocol.eth.proto import ETHProtocol, ETHProtocolV63 from trinity.protocol.les.api import LESV1API, LESV2API -from trinity.protocol.les.proto import LESProtocolV1, LESProtocolV2 from trinity.components.builtin.network_db.connection.tracker import ConnectionTrackerClient from trinity.components.builtin.network_db.eth1_peer_db.tracker import ( @@ -96,22 +94,7 @@ class BaseChainPeer(BasePeer): @cached_property def chain_api(self) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]: - if self.connection.has_logic(ETHAPI.name): - if self.connection.has_protocol(ETHProtocol): - return self.connection.get_logic(ETHAPI.name, ETHAPI) - elif self.connection.has_protocol(ETHProtocolV63): - return self.connection.get_logic(ETHV63API.name, ETHV63API) - else: - raise Exception("Should be unreachable") - elif self.connection.has_logic(LESV1API.name): - if self.connection.has_protocol(LESProtocolV2): - return self.connection.get_logic(LESV2API.name, LESV2API) - elif self.connection.has_protocol(LESProtocolV1): - return self.connection.get_logic(LESV1API.name, LESV1API) - else: - raise Exception("Should be unreachable") - else: - raise Exception("Should be unreachable") + return choose_eth_or_les_api(self.connection) @cached_property def head_info(self) -> HeadInfoAPI: