
Tighten up IterativeFinder async close behavior (DHT iterator continues after consumer breaks out of it) #3599

Merged · 9 commits · May 23, 2022
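
The bug being fixed: when a consumer `break`s out of an `async for` over an `IterativeFinder`, nothing finalizes the iterator, so its probe tasks keep running in the background. The fix makes `aclose()` awaitable and wraps consumers in `aclosing()`. A minimal sketch of the failure mode and the fix, using an illustrative `search()` generator rather than the real finder:

```python
import asyncio


async def search():
    # stand-in for a DHT iterator that keeps producing results
    try:
        while True:
            await asyncio.sleep(0.1)
            yield ["peer"]
    finally:
        print("search finalized")


async def main():
    results = search()
    async for peers in results:
        break  # bug: nothing closes the generator here...
    print("consumer moved on")  # ...it stays suspended until finalized
    await results.aclose()      # the fix guarantees this call happens,
    # which is what `async with aclosing(results):` does on block exit


asyncio.run(main())
```

The same contract is applied to `IterativeFinder` below: its new `aclose()` coroutine cancels the outstanding probes instead of leaving them to the garbage collector.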
67 changes: 34 additions & 33 deletions lbry/dht/node.py
@@ -5,7 +5,7 @@
 
 from prometheus_client import Gauge
 
-from lbry.utils import resolve_host
+from lbry.utils import aclosing, resolve_host
 from lbry.dht import constants
 from lbry.dht.peer import make_kademlia_peer
 from lbry.dht.protocol.distance import Distance
@@ -217,9 +217,10 @@ async def peer_search(self, node_id: bytes, count=constants.K, max_results=const
                           shortlist: typing.Optional[typing.List['KademliaPeer']] = None
                           ) -> typing.List['KademliaPeer']:
         peers = []
-        async for iteration_peers in self.get_iterative_node_finder(
-                node_id, shortlist=shortlist, max_results=max_results):
-            peers.extend(iteration_peers)
+        async with aclosing(self.get_iterative_node_finder(
+                node_id, shortlist=shortlist, max_results=max_results)) as node_finder:
+            async for iteration_peers in node_finder:
+                peers.extend(iteration_peers)
         distance = Distance(node_id)
         peers.sort(key=lambda peer: distance(peer.node_id))
         return peers[:count]
@@ -245,36 +246,36 @@ async def put_into_result_queue_after_pong(_peer):
 
         # prioritize peers who reply to a dht ping first
         # this minimizes attempting to make tcp connections that won't work later to dead or unreachable peers
 
-        async for results in self.get_iterative_value_finder(bytes.fromhex(blob_hash)):
-            to_put = []
-            for peer in results:
-                if peer.address == self.protocol.external_ip and self.protocol.peer_port == peer.tcp_port:
-                    continue
-                is_good = self.protocol.peer_manager.peer_is_good(peer)
-                if is_good:
-                    # the peer has replied recently over UDP, it can probably be reached on the TCP port
-                    to_put.append(peer)
-                elif is_good is None:
-                    if not peer.udp_port:
-                        # TODO: use the same port for TCP and UDP
-                        # the udp port must be guessed
-                        # default to the ports being the same. if the TCP port appears to be <=0.48.0 default,
-                        # including on a network with several nodes, then assume the udp port is proportionately
-                        # based on a starting port of 4444
-                        udp_port_to_try = peer.tcp_port
-                        if 3400 > peer.tcp_port > 3332:
-                            udp_port_to_try = (peer.tcp_port - 3333) + 4444
-                        self.loop.create_task(put_into_result_queue_after_pong(
-                            make_kademlia_peer(peer.node_id, peer.address, udp_port_to_try, peer.tcp_port)
-                        ))
-                    else:
-                        self.loop.create_task(put_into_result_queue_after_pong(peer))
-                else:
-                    # the peer is known to be bad/unreachable, skip trying to connect to it over TCP
-                    log.debug("skip bad peer %s:%i for %s", peer.address, peer.tcp_port, blob_hash)
-            if to_put:
-                result_queue.put_nowait(to_put)
+        async with aclosing(self.get_iterative_value_finder(bytes.fromhex(blob_hash))) as value_finder:
+            async for results in value_finder:
+                to_put = []
+                for peer in results:
+                    if peer.address == self.protocol.external_ip and self.protocol.peer_port == peer.tcp_port:
+                        continue
+                    is_good = self.protocol.peer_manager.peer_is_good(peer)
+                    if is_good:
+                        # the peer has replied recently over UDP, it can probably be reached on the TCP port
+                        to_put.append(peer)
+                    elif is_good is None:
+                        if not peer.udp_port:
+                            # TODO: use the same port for TCP and UDP
+                            # the udp port must be guessed
+                            # default to the ports being the same. if the TCP port appears to be <=0.48.0 default,
+                            # including on a network with several nodes, then assume the udp port is proportionately
+                            # based on a starting port of 4444
+                            udp_port_to_try = peer.tcp_port
+                            if 3400 > peer.tcp_port > 3332:
+                                udp_port_to_try = (peer.tcp_port - 3333) + 4444
+                            self.loop.create_task(put_into_result_queue_after_pong(
+                                make_kademlia_peer(peer.node_id, peer.address, udp_port_to_try, peer.tcp_port)
+                            ))
+                        else:
+                            self.loop.create_task(put_into_result_queue_after_pong(peer))
+                    else:
+                        # the peer is known to be bad/unreachable, skip trying to connect to it over TCP
+                        log.debug("skip bad peer %s:%i for %s", peer.address, peer.tcp_port, blob_hash)
+                if to_put:
+                    result_queue.put_nowait(to_put)
 
     def accumulate_peers(self, search_queue: asyncio.Queue,
                          peer_queue: typing.Optional[asyncio.Queue] = None
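
One detail in the hunk above worth a sanity check: the UDP port guess remaps TCP ports in the old 3333-based default range onto the 4444-based UDP default, preserving each node's offset. A standalone check of the arithmetic (`guess_udp_port` is an illustrative helper, not part of the codebase):

```python
def guess_udp_port(tcp_port: int) -> int:
    # default to the ports being the same; if the TCP port looks like an
    # old 3333-based default, shift the same offset onto the 4444 base
    udp_port_to_try = tcp_port
    if 3400 > tcp_port > 3332:
        udp_port_to_try = (tcp_port - 3333) + 4444
    return udp_port_to_try


assert guess_udp_port(3333) == 4444  # first node on a host
assert guess_udp_port(3335) == 4446  # third node keeps its +2 offset
assert guess_udp_port(4444) == 4444  # already on the new default
```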
50 changes: 38 additions & 12 deletions lbry/dht/protocol/iterative_find.py
@@ -1,6 +1,7 @@
 import asyncio
 from itertools import chain
 from collections import defaultdict, OrderedDict
+from collections.abc import AsyncIterator
 import typing
 import logging
 from typing import TYPE_CHECKING
@@ -71,7 +72,7 @@ def get_shortlist(routing_table: 'TreeRoutingTable', key: bytes,
     return shortlist or routing_table.find_close_peers(key)
 
 
-class IterativeFinder:
+class IterativeFinder(AsyncIterator):
     def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
                  routing_table: 'TreeRoutingTable', protocol: 'KademliaProtocol', key: bytes,
                  max_results: typing.Optional[int] = constants.K,
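
Deriving from `collections.abc.AsyncIterator` makes the protocol explicit: the ABC declares `__anext__` abstract and supplies a default `__aiter__` that returns `self` (IterativeFinder still overrides `__aiter__` to run its start-up logic). A minimal subclass, just to show what the base class provides:

```python
import asyncio
from collections.abc import AsyncIterator


class Countdown(AsyncIterator):
    """Only __anext__ is required; __aiter__ is inherited from the ABC."""

    def __init__(self, n: int):
        self.n = n

    async def __anext__(self) -> int:
        if self.n <= 0:
            raise StopAsyncIteration
        self.n -= 1
        return self.n


async def main():
    async for i in Countdown(3):
        print(i)  # 2, 1, 0


asyncio.run(main())
```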
@@ -151,7 +152,7 @@ async def _handle_probe_result(self, peer: 'KademliaPeer', response: FindRespons
                     log.warning("misbehaving peer %s:%i returned peer with reserved ip %s:%i", peer.address,
                                 peer.udp_port, address, udp_port)
         self.check_result_ready(response)
-        self._log_state()
+        self._log_state(reason="check result")
 
     def _reset_closest(self, peer):
         if peer in self.active:
@@ -163,12 +164,17 @@ async def _send_probe(self, peer: 'KademliaPeer'):
         except asyncio.TimeoutError:
             self._reset_closest(peer)
             return
+        except asyncio.CancelledError:
+            log.debug("%s[%x] cancelled probe",
+                      type(self).__name__, id(self))
+            raise
         except ValueError as err:
             log.warning(str(err))
             self._reset_closest(peer)
             return
         except TransportNotConnected:
-            return self.aclose()
+            await self._aclose(reason="not connected")
+            return
         except RemoteException:
             self._reset_closest(peer)
             return
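
The new `except asyncio.CancelledError` arm logs the cancelled probe and re-raises rather than swallowing it; returning normally would make the probe task look like it finished cleanly to whoever cancelled it. A generic sketch of that discipline (no lbry names):

```python
import asyncio


async def probe():
    try:
        await asyncio.sleep(10)  # stand-in for a network round trip
    except asyncio.CancelledError:
        print("probe cancelled")  # observe/log the cancellation...
        raise                     # ...but always let it propagate


async def main():
    task = asyncio.create_task(probe())
    await asyncio.sleep(0)  # give the probe a chance to start
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("caller sees the cancellation too")


asyncio.run(main())
```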
@@ -182,7 +188,9 @@ def _search_round(self):
         added = 0
         for index, peer in enumerate(self.active.keys()):
             if index == 0:
-                log.debug("closest to probe: %s", peer.node_id.hex()[:8])
+                log.debug("%s[%x] closest to probe: %s",
+                          type(self).__name__, id(self),
+                          peer.node_id.hex()[:8])
             if peer in self.contacted:
                 continue
             if len(self.running_probes) >= constants.ALPHA:
@@ -198,9 +206,13 @@
                 continue
             self._schedule_probe(peer)
             added += 1
-        log.debug("running %d probes for key %s", len(self.running_probes), self.key.hex()[:8])
+        log.debug("%s[%x] running %d probes for key %s",
+                  type(self).__name__, id(self),
+                  len(self.running_probes), self.key.hex()[:8])
         if not added and not self.running_probes:
-            log.debug("search for %s exhausted", self.key.hex()[:8])
+            log.debug("%s[%x] search for %s exhausted",
+                      type(self).__name__, id(self),
+                      self.key.hex()[:8])
             self.search_exhausted()
 
     def _schedule_probe(self, peer: 'KademliaPeer'):
@@ -216,9 +228,11 @@ def callback(_):
         t.add_done_callback(callback)
         self.running_probes[peer] = t
 
-    def _log_state(self):
-        log.debug("[%s] check result: %i active nodes %i contacted",
-                  self.key.hex()[:8], len(self.active), len(self.contacted))
+    def _log_state(self, reason="?"):
+        log.debug("%s[%x] [%s] %s: %i active nodes %i contacted %i produced %i queued",
+                  type(self).__name__, id(self), self.key.hex()[:8],
+                  reason, len(self.active), len(self.contacted),
+                  self.iteration_count, self.iteration_queue.qsize())
 
     def __aiter__(self):
         if self.running:
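
All the reworked log lines share a `%s[%x]` prefix built from `type(self).__name__` and `id(self)`, so interleaved output from concurrent finder instances (and subclasses) can be told apart. A tiny demonstration of the idiom outside the codebase:

```python
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("dht")


class Finder:
    def _log_state(self, reason="?"):
        # renders like "Finder[7f9c1a2b3c40] state: check result"
        log.debug("%s[%x] state: %s", type(self).__name__, id(self), reason)


a, b = Finder(), Finder()
a._log_state("check result")  # the %x ids distinguish a from b
b._log_state("check result")
```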
@@ -237,18 +251,30 @@ async def __anext__(self) -> typing.List['KademliaPeer']:
                 raise StopAsyncIteration
             self.iteration_count += 1
             return result
-        except (asyncio.CancelledError, StopAsyncIteration):
-            self.loop.call_soon(self.aclose)
+        except asyncio.CancelledError:
+            await self._aclose(reason="cancelled")
+            raise
+        except StopAsyncIteration:
+            await self._aclose(reason="no more results")
             raise
 
-    def aclose(self):
+    async def _aclose(self, reason="?"):
+        log.debug("%s[%x] [%s] shutdown because %s: %i active nodes %i contacted %i produced %i queued",
+                  type(self).__name__, id(self), self.key.hex()[:8],
+                  reason, len(self.active), len(self.contacted),
+                  self.iteration_count, self.iteration_queue.qsize())
         self.running = False
         self.iteration_queue.put_nowait(None)
         for task in chain(self.tasks, self.running_probes.values()):
             task.cancel()
         self.tasks.clear()
         self.running_probes.clear()
+
+    async def aclose(self):
+        if self.running:
+            await self._aclose(reason="aclose")
+        log.debug("%s[%x] [%s] async close completed",
+                  type(self).__name__, id(self), self.key.hex()[:8])
 
 class IterativeNodeFinder(IterativeFinder):
     def __init__(self, loop: asyncio.AbstractEventLoop, peer_manager: 'PeerManager',
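
Previously `__anext__` scheduled the synchronous `aclose()` with `self.loop.call_soon(...)`, so teardown raced with whatever the caller did next. Making close a coroutine lets callers and `aclosing()` await it, and the `if self.running` guard keeps it idempotent when iteration has already shut itself down. A stripped-down sketch of that close contract (a hypothetical `Finder`, not the real class):

```python
import asyncio


class Finder:
    def __init__(self):
        self.running = True
        self.queue = asyncio.Queue()

    async def _aclose(self, reason="?"):
        self.running = False
        self.queue.put_nowait(None)  # wake any consumer blocked on get()

    async def aclose(self):
        if self.running:  # no-op if iteration already closed itself
            await self._aclose(reason="aclose")


async def main():
    finder = Finder()
    await finder.aclose()  # teardown is finished when this returns
    await finder.aclose()  # safe to call again


asyncio.run(main())
```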
10 changes: 10 additions & 0 deletions lbry/utils.py
@@ -130,6 +130,16 @@ def get_sd_hash(stream_info):
 def json_dumps_pretty(obj, **kwargs):
     return json.dumps(obj, sort_keys=True, indent=2, separators=(',', ': '), **kwargs)
 
+try:
+    # the standard contextlib.aclosing() is available in 3.10+
+    from contextlib import aclosing  # pylint: disable=unused-import
+except ImportError:
+    @contextlib.asynccontextmanager
+    async def aclosing(thing):
+        try:
+            yield thing
+        finally:
+            await thing.aclose()
 
 def async_timed_cache(duration: int):
     def wrapper(func):
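
Whichever branch wins (stdlib `contextlib.aclosing` on 3.10+, the equivalent wrapper above otherwise), callers use it identically. A short usage sketch with an illustrative `ticker()` generator:

```python
import asyncio
from lbry.utils import aclosing  # stdlib aclosing on 3.10+, else the backport


async def ticker():
    n = 0
    try:
        while True:
            yield n
            n += 1
    finally:
        print("ticker finalized promptly")


async def main():
    async with aclosing(ticker()) as ticks:
        async for n in ticks:
            if n == 2:
                break  # the context manager awaits ticks.aclose() here


asyncio.run(main())
```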