Skip to content

Commit

Permalink
document memoryview usage, minor frame_writer.write_frame refactor (#384
Browse files Browse the repository at this point in the history
)

* document memory_view usage, refactor frame_writer.write_frame

* improve test for changing frame_max in write_frame

* add integration test for write_frame/send_heartbeat
  • Loading branch information
pawl authored Dec 24, 2021
1 parent 935a06b commit 892389c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 15 deletions.
19 changes: 11 additions & 8 deletions amqp/method_framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def buf(self):
@buf.setter
def buf(self, buf):
self._buf = buf
# Using a memoryview allows slicing without copying underlying data.
# Slicing this is much faster than slicing the bytearray directly.
# More details: https://stackoverflow.com/a/34257357
self.view = memoryview(buf)


Expand All @@ -107,13 +110,6 @@ def frame_writer(connection, transport,

def write_frame(type_, channel, method_sig, args, content):
chunk_size = connection.frame_max - 8
# frame_max can be updated via connection._on_tune. If
# it became larger, then we need to resize the buffer
# to prevent overflow.
if chunk_size > len(buffer_store.buf):
buffer_store.buf = bytearray(chunk_size)
buf = buffer_store.buf
view = buffer_store.view
offset = 0
properties = None
args = str_to_bytes(args)
Expand Down Expand Up @@ -155,6 +151,13 @@ def write_frame(type_, channel, method_sig, args, content):
frame, 0xce))

else:
# frame_max can be updated via connection._on_tune. If
# it became larger, then we need to resize the buffer
# to prevent overflow.
if chunk_size > len(buffer_store.buf):
buffer_store.buf = bytearray(chunk_size)
buf = buffer_store.buf

# ## FAST: pack into buffer and single write
frame = (b''.join([pack('>HH', *method_sig), args])
if type_ == 1 else b'')
Expand All @@ -180,7 +183,7 @@ def write_frame(type_, channel, method_sig, args, content):
3, channel, framelen, body, 0xce)
offset += 8 + framelen

write(view[:offset])
write(buffer_store.view[:offset])

connection.bytes_sent += 1
return write_frame
11 changes: 11 additions & 0 deletions t/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,17 @@ def test_connection_closed_by_broker(self):
)
callback_mock.assert_called_once_with()

def test_send_heartbeat(self):
"""The send_heartbeat method writes the expected output."""
conn = Connection()
with patch.object(conn, 'Transport') as transport_mock:
handshake(conn, transport_mock)
transport_mock().write.reset_mock()
conn.send_heartbeat()
transport_mock().write.assert_called_once_with(
memoryview(bytearray(b'\x08\x00\x00\x00\x00\x00\x00\xce'))
)


class test_channel:
# Integration tests. Tests verify the correctness of communication between
Expand Down
27 changes: 20 additions & 7 deletions t/unit/test_method_framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,24 @@ def test_write_non_utf8(self):
assert 'body'.encode('utf-16') in memory.tobytes()
assert msg.properties['content_encoding'] == 'utf-16'

def test_frame_max_update(self):
msg = Message(body='t' * (self.connection.frame_max + 10))
frame = 2, 1, spec.Basic.Publish, b'x' * 10, msg
def test_write_frame__fast__buffer_store_resize(self):
"""The buffer_store is resized when the connection's frame_max is increased."""
small_msg = Message(body='t')
small_frame = 2, 1, spec.Basic.Publish, b'x' * 10, small_msg
self.g(*small_frame)
self.write.assert_called_once()
write_arg = self.write.call_args[0][0]
assert isinstance(write_arg, memoryview)
assert len(write_arg) < self.connection.frame_max
self.connection.reset_mock()

# write a larger message to the same frame_writer after increasing frame_max
large_msg = Message(body='t' * (self.connection.frame_max + 10))
large_frame = 2, 1, spec.Basic.Publish, b'x' * 10, large_msg
original_frame_max = self.connection.frame_max
self.connection.frame_max += 100
self.g(*frame)
self.write.assert_called()
memory = self.write.call_args[0][0]
assert isinstance(memory, memoryview)
self.g(*large_frame)
self.write.assert_called_once()
write_arg = self.write.call_args[0][0]
assert isinstance(write_arg, memoryview)
assert len(write_arg) > original_frame_max

0 comments on commit 892389c

Please sign in to comment.