Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to Trio 0.9.0 #5

Merged
merged 4 commits into from
Jan 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
'pytest-trio >= 0.3',
],
install_requires=[
'trio',
'trio >= 0.9.0',
],
packages=[
'trio_amqp',
Expand Down
35 changes: 20 additions & 15 deletions trio_amqp/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def __init__(self, channel, consumer_tag, **kwargs):

async def _data(self, channel, msg, env, prop):
if msg is None:
await self._q.put(None)
await self._chan_send.send(None)
else:
await self._q.put((msg, env, prop))
await self._chan_send.send((msg, env, prop))

if sys.version_info >= (3,5,3):
def __aiter__(self):
Expand All @@ -44,14 +44,15 @@ async def __aiter__(self):
return self

async def __anext__(self):
res = await self._q.get()
res = await self._chan_receive.receive()
if res is None:
raise StopAsyncIteration
return res

async def __aenter__(self):
await self.channel.basic_consume(self._data, consumer_tag=self.consumer_tag, **self.kwargs)
self._q = trio.Queue(30) # TODO: 2 + possible prefetch
# TODO: 2 + possible prefetch
self._chan_send, self._chan_receive = trio.open_memory_channel(30)
return self

async def __aexit__(self, *tb):
Expand All @@ -60,7 +61,8 @@ async def __aexit__(self, *tb):
await self.channel.basic_cancel(self.consumer_tag)
except AmqpClosedConnection:
pass
del self._q
del self._chan_send
del self._chan_receive
# these messages are not acknowledged, thus deleting the queue will
# not lose them

Expand All @@ -75,7 +77,6 @@ def __iter__(self):


class Channel:
_q = None # for returned messages

def __init__(self, protocol, channel_id):
self.protocol = protocol
Expand All @@ -97,9 +98,13 @@ def __init__(self, protocol, channel_id):
self._futures = {}
self._ctag_events = {}

self._chan_send = None
self._chan_receive = None

def __aiter__(self):
if self._q is None:
self._q = trio.Queue(30) # TODO: 2 + possible prefetch
if self._chan_send is None:
# TODO: 2 + possible prefetch
self._chan_send, self._chan_receive = trio.open_memory_channel(30)
return self

if sys.version_info < (3,5,3):
Expand All @@ -108,7 +113,7 @@ async def __aiter__(self):
return self._aiter()

async def __anext__(self):
res = await self._q.get()
res = await self._chan_receive.receive()
if res is None:
raise StopAsyncIteration
return res
Expand Down Expand Up @@ -149,8 +154,8 @@ def connection_closed(self, server_code=None, server_reason=None, exception=None

self.protocol.release_channel_id(self.channel_id)
self.close_event.set()
if self._q is not None:
self._q.put_nowait(None)
if self._chan_send is not None:
self._chan_send.send_nowait(None)

async def dispatch_frame(self, frame):
methods = {
Expand Down Expand Up @@ -271,8 +276,8 @@ async def close(self, reply_code=0, reply_text="Normal Shutdown"):
if not self.is_open:
raise exceptions.ChannelClosed("channel already closed or closing")
self.close_event.set()
if self._q is not None:
self._q.put_nowait(None)
if self._chan_send is not None:
self._chan_send.send_nowait(None)
frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id)
frame.declare_method(amqp_constants.CLASS_CHANNEL, amqp_constants.CHANNEL_CLOSE)
request = amqp_frame.AmqpEncoder()
Expand Down Expand Up @@ -946,11 +951,11 @@ async def basic_return(self, frame):
envelope = ReturnEnvelope(reply_code, reply_text,
exchange_name, routing_key)
properties = content_header_frame.properties
if self._q is None:
if self._chan_send is None:
# they have set mandatory bit, but havent added a callback
logger.warning("You don't iterate the channel for returned messages!")
else:
await self._q.put((body, envelope, properties))
await self._chan_send.send((body, envelope, properties))

async def basic_get(self, queue_name='', no_ack=False):
frame = amqp_frame.AmqpRequest(amqp_constants.TYPE_METHOD, self.channel_id)
Expand Down
14 changes: 7 additions & 7 deletions trio_amqp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def _drain(self):
async def _write_frame(self, frame, encoder, drain=True):
# Doesn't actually write frame, pushes it for _writer_loop task to
# pick it up.
await self._send_queue.put((frame, encoder))
await self._send_send_channel.send((frame, encoder))

@trio.hazmat.enable_ki_protection
async def _writer_loop(self, task_status=trio.TASK_STATUS_IGNORED):
Expand All @@ -216,15 +216,15 @@ async def _writer_loop(self, task_status=trio.TASK_STATUS_IGNORED):
timeout = inf

with trio.move_on_after(timeout) as timeout_scope:
frame, encoder = await self._send_queue.get()
frame, encoder = await self._send_receive_channel.receive()
if timeout_scope.cancelled_caught:
await self.send_heartbeat()
continue

f = frame.get_frame(encoder)
try:
await self._stream.send_all(f)
except (trio.BrokenStreamError,trio.ClosedStreamError):
except (trio.BrokenResourceError, trio.ClosedResourceError):
# raise exceptions.AmqpClosedConnection(self) from None
# the reader will raise the error also
return
Expand Down Expand Up @@ -258,7 +258,7 @@ async def aclose(self, no_wait=False):
encoder.write_short(0)
try:
await self._write_frame(frame, encoder)
except trio.ClosedStreamError:
except trio.BrokenResourceError:
pass
except Exception:
logger.exception("Error while closing")
Expand Down Expand Up @@ -315,7 +315,7 @@ async def __aenter__(self):
self.server_channel_max = None
self.channels_ids_ceil = 0
self.channels_ids_free = set()
self._send_queue = trio.Queue(1)
self._send_send_channel, self._send_receive_channel = trio.open_memory_channel(1)

if self._ssl:
if self._ssl is True:
Expand Down Expand Up @@ -423,7 +423,7 @@ async def get_frame(self):
frame = amqp_frame.AmqpResponse(self._stream)
try:
await frame.read_frame()
except trio.BrokenStreamError:
except trio.BrokenResourceError:
raise exceptions.AmqpClosedConnection(self) from None

return frame
Expand Down Expand Up @@ -511,7 +511,7 @@ async def _reader_loop(self, task_status=trio.TASK_STATUS_IGNORED):
with trio.fail_after(timeout):
try:
frame = await self.get_frame()
except trio.ClosedStreamError:
except (trio.BrokenResourceError, trio.ClosedResourceError):
# the stream is now *really* closed …
return
try:
Expand Down