From 892389c20747ad8d98d2f3fd02ee029685289fe6 Mon Sep 17 00:00:00 2001 From: Paul Brown Date: Fri, 24 Dec 2021 07:47:42 +0000 Subject: [PATCH] document memoryview usage, minor frame_writer.write_frame refactor (#384) * document memory_view usage, refactor frame_writer.write_frame * improve test for changing frame_max in write_frame * add integration test for write_frame/send_heartbeat --- amqp/method_framing.py | 19 +++++++++++-------- t/integration/test_integration.py | 11 +++++++++++ t/unit/test_method_framing.py | 27 ++++++++++++++++++++------- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/amqp/method_framing.py b/amqp/method_framing.py index 5fe05052..6c49833f 100644 --- a/amqp/method_framing.py +++ b/amqp/method_framing.py @@ -94,6 +94,9 @@ def buf(self): @buf.setter def buf(self, buf): self._buf = buf + # Using a memoryview allows slicing without copying underlying data. + # Slicing this is much faster than slicing the bytearray directly. + # More details: https://stackoverflow.com/a/34257357 self.view = memoryview(buf) @@ -107,13 +110,6 @@ def frame_writer(connection, transport, def write_frame(type_, channel, method_sig, args, content): chunk_size = connection.frame_max - 8 - # frame_max can be updated via connection._on_tune. If - # it became larger, then we need to resize the buffer - # to prevent overflow. - if chunk_size > len(buffer_store.buf): - buffer_store.buf = bytearray(chunk_size) - buf = buffer_store.buf - view = buffer_store.view offset = 0 properties = None args = str_to_bytes(args) @@ -155,6 +151,13 @@ def write_frame(type_, channel, method_sig, args, content): frame, 0xce)) else: + # frame_max can be updated via connection._on_tune. If + # it became larger, then we need to resize the buffer + # to prevent overflow. + if chunk_size > len(buffer_store.buf): + buffer_store.buf = bytearray(chunk_size) + buf = buffer_store.buf + # ## FAST: pack into buffer and single write frame = (b''.join([pack('>HH', *method_sig), args]) if type_ == 1 else b'') @@ -180,7 +183,7 @@ def write_frame(type_, channel, method_sig, args, content): 3, channel, framelen, body, 0xce) offset += 8 + framelen - write(view[:offset]) + write(buffer_store.view[:offset]) connection.bytes_sent += 1 return write_frame diff --git a/t/integration/test_integration.py b/t/integration/test_integration.py index d4414883..d0b67b92 100644 --- a/t/integration/test_integration.py +++ b/t/integration/test_integration.py @@ -416,6 +416,17 @@ def test_connection_closed_by_broker(self): ) callback_mock.assert_called_once_with() + def test_send_heartbeat(self): + """The send_heartbeat method writes the expected output.""" + conn = Connection() + with patch.object(conn, 'Transport') as transport_mock: + handshake(conn, transport_mock) + transport_mock().write.reset_mock() + conn.send_heartbeat() + transport_mock().write.assert_called_once_with( + memoryview(bytearray(b'\x08\x00\x00\x00\x00\x00\x00\xce')) + ) + class test_channel: # Integration tests. Tests verify the correctness of communication between diff --git a/t/unit/test_method_framing.py b/t/unit/test_method_framing.py index bd9c465a..96ecf03c 100644 --- a/t/unit/test_method_framing.py +++ b/t/unit/test_method_framing.py @@ -138,11 +138,24 @@ def test_write_non_utf8(self): assert 'body'.encode('utf-16') in memory.tobytes() assert msg.properties['content_encoding'] == 'utf-16' - def test_frame_max_update(self): - msg = Message(body='t' * (self.connection.frame_max + 10)) - frame = 2, 1, spec.Basic.Publish, b'x' * 10, msg + def test_write_frame__fast__buffer_store_resize(self): + """The buffer_store is resized when the connection's frame_max is increased.""" + small_msg = Message(body='t') + small_frame = 2, 1, spec.Basic.Publish, b'x' * 10, small_msg + self.g(*small_frame) + self.write.assert_called_once() + write_arg = self.write.call_args[0][0] + assert isinstance(write_arg, memoryview) + assert len(write_arg) < self.connection.frame_max + self.connection.reset_mock() + + # write a larger message to the same frame_writer after increasing frame_max + large_msg = Message(body='t' * (self.connection.frame_max + 10)) + large_frame = 2, 1, spec.Basic.Publish, b'x' * 10, large_msg + original_frame_max = self.connection.frame_max self.connection.frame_max += 100 - self.g(*frame) - self.write.assert_called() - memory = self.write.call_args[0][0] - assert isinstance(memory, memoryview) + self.g(*large_frame) + self.write.assert_called_once() + write_arg = self.write.call_args[0][0] + assert isinstance(write_arg, memoryview) + assert len(write_arg) > original_frame_max