Skip to content

Commit

Permalink
Rework NoBsWs to avoid agen/trio incompatibility
Browse files Browse the repository at this point in the history
`trio`'s internals don't allow for async generator (and thus by
consequence dynamic reset of async exit stacks containing `@acm`s)
interleaving since doing so corrupts the cancel-scope stack. See details
in:
- python-trio/trio#638
- https://trio-util.readthedocs.io/en/latest/#trio_util.trio_async_generator

We originally tried to address this using
`@trio_util.trio_async_generator` in backend streaming code but for
whatever reason stopped working recently (at least for me) and it's more
or less implemented the same way as this patch but with more layers and
an extra dep. I also don't want us to have to address this problem again
if/when that lib isn't able to keep up to date with wtv `trio` is
doing..

So instead this is a complete rewrite of the conc design of our
auto-reconnect ws API to move all reset logic and msg relay into a bg
task which is respawned on reset-requiring events: user spec-ed msg recv
latency, network errors, roaming events.

Deatz:
- drop all usage of `AsyncExitStack` and no longer require client code
  to (hackily) call `NoBsWs._connect()` on msg latency conditions,
  intead this is all done behind the scenes and the user can instead
  pass in a `msg_recv_timeout: float`.
- massively simplify impl of `NoBsWs` and move all reset logic into a
  new `_reconnect_forever()` task.
- offer use of `reset_after: int` a count value that determines how many
  `msg_recv_timeout` events are allowed to occur before reconnecting the
  entire ws from scratch again.
  • Loading branch information
goodboy committed Apr 21, 2023
1 parent ac4a8a3 commit 2e2f49e
Showing 1 changed file with 231 additions and 81 deletions.
312 changes: 231 additions & 81 deletions piker/data/_web_bs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for piker0)
# Copyright (C) Tyler Goodlet (in stewardship for pikers)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
Expand All @@ -18,23 +18,29 @@
ToOlS fOr CoPInG wITh "tHE wEB" protocols.
"""
from __future__ import annotations
from contextlib import (
asynccontextmanager,
AsyncExitStack,
asynccontextmanager as acm,
)
from itertools import count
from functools import partial
from types import ModuleType
from typing import (
Any,
Optional,
Callable,
AsyncContextManager,
AsyncGenerator,
Iterable,
)
import json

import trio
import trio_websocket
from trio_typing import TaskStatus
from trio_websocket import (
WebSocketConnection,
open_websocket_url,
)
from wsproto.utilities import LocalProtocolError
from trio_websocket._impl import (
ConnectionClosed,
Expand All @@ -52,9 +58,15 @@ class NoBsWs:
'''
Make ``trio_websocket`` sockets stay up no matter the bs.
You can provide a ``fixture`` async-context-manager which will be
enter/exitted around each reconnect operation.
A shim interface that allows client code to stream from some
``WebSocketConnection`` but where any connectivy bs is handled
automatcially and entirely in the background.
NOTE: this type should never be created directly but instead is
provided via the ``open_autorecon_ws()`` factor below.
'''
# apparently we can QoS for all sorts of reasons..so catch em.
recon_errors = (
ConnectionClosed,
DisconnectionTimeout,
Expand All @@ -67,115 +79,253 @@ class NoBsWs:
def __init__(
self,
url: str,
stack: AsyncExitStack,
fixture: Optional[Callable] = None,
rxchan: trio.MemoryReceiveChannel,
msg_recv_timeout: float,

serializer: ModuleType = json
):
self.url = url
self.fixture = fixture
self._stack = stack
self._ws: 'WebSocketConnection' = None # noqa

# TODO: is there some method we can call
# on the underlying `._ws` to get this?
self._connected: bool = False

async def _connect(
self,
tries: int = 1000,
) -> None:

self._connected = False
while True:
try:
await self._stack.aclose()
except self.recon_errors:
await trio.sleep(0.5)
else:
break

last_err = None
for i in range(tries):
try:
self._ws = await self._stack.enter_async_context(
trio_websocket.open_websocket_url(self.url)
)
self._rx = rxchan
self._timeout = msg_recv_timeout

if self.fixture is not None:
# rerun user code fixture
ret = await self._stack.enter_async_context(
self.fixture(self)
)
# signaling between caller and relay task which determines when
# socket is connected (and subscribed).
self._connected: trio.Event = trio.Event()

assert ret is None
# dynamically reset by the bg relay task
self._ws: WebSocketConnection | None = None
self._cs: trio.CancelScope | None = None

log.info(f'Connection success: {self.url}')
# interchange codec methods
# TODO: obviously the method API here may be different
# for another interchange format..
self._dumps: Callable = serializer.dumps
self._loads: Callable = serializer.loads

self._connected = True
return self._ws
def connected(self) -> bool:
return self._connected.is_set()

except self.recon_errors as err:
last_err = err
log.error(
f'{self} connection bail with '
f'{type(err)}...retry attempt {i}'
)
await trio.sleep(0.5)
self._connected = False
continue
else:
log.exception('ws connection fail...')
raise last_err
async def reset(self) -> None:
'''
Reset the underlying ws connection by cancelling
the bg relay task and waiting for it to signal
a new connection.
def connected(self) -> bool:
return self._connected
'''
self._connected = trio.Event()
self._cs.cancel()
await self._connected.wait()

async def send_msg(
self,
data: Any,
) -> None:
while True:
try:
return await self._ws.send_message(json.dumps(data))
msg: Any = self._dumps(data)
return await self._ws.send_message(msg)
except self.recon_errors:
await self._connect()
await self.reset()

async def recv_msg(
self,
) -> Any:
while True:
try:
return json.loads(await self._ws.get_message())
except self.recon_errors:
await self._connect()
async def recv_msg(self) -> Any:
msg: Any = await self._rx.receive()
data = self._loads(msg)
return data

def __aiter__(self):
return self

async def __anext__(self):
return await self.recv_msg()

def set_recv_timeout(
self,
timeout: float,
) -> None:
self._timeout = timeout


async def _reconnect_forever(
url: str,
snd: trio.MemorySendChannel,
nobsws: NoBsWs,
reset_after: int, # msg recv timeout before reset attempt

fixture: AsyncContextManager | None = None,
task_status: TaskStatus = trio.TASK_STATUS_IGNORED,

@asynccontextmanager
) -> None:

async def proxy_msgs(
ws: WebSocketConnection,
pcs: trio.CancelScope, # parent cancel scope
):
'''
Receive (under `timeout` deadline) all msgs from from underlying
websocket and relay them to (calling) parent task via ``trio``
mem chan.
'''
# after so many msg recv timeouts, reset the connection
timeouts: int = 0

while True:
with trio.move_on_after(
# can be dynamically changed by user code
nobsws._timeout,
) as cs:
try:
msg: Any = await ws.get_message()
await snd.send(msg)
except nobsws.recon_errors:
log.exception(
f'{url} connection bail with:'
)
await trio.sleep(0.5)
pcs.cancel()

# go back to reonnect loop in parent task
return

if cs.cancelled_caught:
timeouts += 1
if timeouts > reset_after:
log.error(
'WS feed seems down and slow af? .. resetting\n'
)
pcs.cancel()

# go back to reonnect loop in parent task
return

async def open_fixture(
fixture: AsyncContextManager,
nobsws: NoBsWs,
task_status: TaskStatus = trio.TASK_STATUS_IGNORED,
):
'''
Open user provided `@acm` and sleep until any connection
reset occurs.
'''
async with fixture(nobsws) as ret:
assert ret is None
task_status.started()
await trio.sleep_forever()

# last_err = None
nobsws._connected = trio.Event()
task_status.started()

while not snd._closed:
log.info(f'{url} trying (RE)CONNECT')

async with trio.open_nursery() as n:
cs = nobsws._cs = n.cancel_scope
ws: WebSocketConnection
async with open_websocket_url(url) as ws:
nobsws._ws = ws
log.info(f'Connection success: {url}')

# begin relay loop to forward msgs
n.start_soon(
proxy_msgs,
ws,
cs,
)

if fixture is not None:
log.info(f'Entering fixture: {fixture}')

# TODO: should we return an explicit sub-cs
# from this fixture task?
await n.start(
open_fixture,
fixture,
nobsws,
)

# indicate to wrapper / opener that we are up and block
# to let tasks run **inside** the ws open block above.
nobsws._connected.set()
await trio.sleep_forever()

# ws open block end
# nursery block end
nobsws._connected = trio.Event()
if cs.cancelled_caught:
log.cancel(
f'{url} connection cancelled!'
)
# if wrapper cancelled us, we expect it to also
# have re-assigned a new event
assert (
nobsws._connected
and not nobsws._connected.is_set()
)

# -> from here, move to next reconnect attempt

else:
log.exception('ws connection closed by client...')


@acm
async def open_autorecon_ws(
url: str,

# TODO: proper type cannot smh
fixture: Optional[Callable] = None,
fixture: AsyncContextManager | None = None,

# time in sec
msg_recv_timeout: float = 3,

# count of the number of above timeouts before connection reset
reset_after: int = 3,

) -> AsyncGenerator[tuple[...], NoBsWs]:
"""Apparently we can QoS for all sorts of reasons..so catch em.
'''
An auto-reconnect websocket (wrapper API) around
``trio_websocket.open_websocket_url()`` providing automatic
re-connection on network errors, msg latency and thus roaming.
"""
async with AsyncExitStack() as stack:
ws = NoBsWs(url, stack, fixture=fixture)
await ws._connect()
Here we implement a re-connect websocket interface where a bg
nursery runs ``WebSocketConnection.receive_message()``s in a loop
and restarts the full http(s) handshake on catches of certain
connetivity errors, or some user defined recv timeout.
try:
yield ws
You can provide a ``fixture`` async-context-manager which will be
entered/exitted around each connection reset; eg. for (re)requesting
subscriptions without requiring streaming setup code to rerun.
'''
snd: trio.MemorySendChannel
rcv: trio.MemoryReceiveChannel
snd, rcv = trio.open_memory_channel(616)

async with trio.open_nursery() as n:
nobsws = NoBsWs(
url,
rcv,
msg_recv_timeout=msg_recv_timeout,
)
await n.start(
partial(
_reconnect_forever,
url,
snd,
nobsws,
fixture=fixture,
reset_after=reset_after,
)
)
await nobsws._connected.wait()
assert nobsws._cs
assert nobsws.connected()

try:
yield nobsws
finally:
await stack.aclose()
n.cancel_scope.cancel()


'''
Expand All @@ -192,7 +342,7 @@ class JSONRPCResult(Struct):
error: Optional[dict] = None


@asynccontextmanager
@acm
async def open_jsonrpc_session(
url: str,
start_id: int = 0,
Expand Down

0 comments on commit 2e2f49e

Please sign in to comment.