Skip to content
This repository has been archived by the owner on Sep 13, 2024. It is now read-only.

Add support for driving BufferedProtocol instances using sock_recv_into #60

Merged
merged 8 commits into from
May 4, 2024
1 change: 1 addition & 0 deletions changes/60.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for driving ``BufferedProtocol`` instances using ``sock_recv_into`` was added.
8 changes: 8 additions & 0 deletions src/gbulb/glib_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,14 @@ def read_func(channel, nbytes):

return self._channel_read(channel, nbytes, read_func)

def sock_recv_into(self, sock, buf, flags=0):
channel = self._channel_from_socket(sock)

def read_func(channel, nbytes):
return sock.recv_into(buf, flags)

return self._channel_read(channel, len(buf), read_func)

def sock_recvfrom(self, sock, nbytes, flags=0):
channel = self._channel_from_socket(sock)

Expand Down
94 changes: 68 additions & 26 deletions src/gbulb/transports.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import collections
import io
import socket
import subprocess
from asyncio import CancelledError, InvalidStateError, base_subprocess, transports
Expand All @@ -13,12 +15,12 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None):

self._loop = loop
self._sock = sock
self._protocol = protocol
self._server = server
self._closing = False
self._closing_delayed = False
self._closed = False
self._cancelable = set()
self.set_protocol(protocol)

if sock is not None:
self._loop._transports[sock.fileno()] = self
Expand Down Expand Up @@ -83,15 +85,23 @@ def _force_close_async(self, exc):


class ReadTransport(BaseTransport, transports.ReadTransport):
max_size = 256 * 1024
max_size = io.DEFAULT_BUFFER_SIZE

def __init__(self, *args, **kwargs):
BaseTransport.__init__(self, *args, **kwargs)

self._paused = False
self._read_fut = None
self._read_buffer = None
self._alloc_read_buffers = False

BaseTransport.__init__(self, *args, **kwargs)

self._loop.call_soon(self._loop_reading)

def set_protocol(self, protocol):
if hasattr(asyncio, "BufferedProtocol"): # Python 3.7+
self._alloc_read_buffers = isinstance(protocol, asyncio.BufferedProtocol)
super().set_protocol(protocol)

def pause_reading(self):
if self._closing:
raise RuntimeError("Cannot pause_reading() when closing")
Expand Down Expand Up @@ -120,22 +130,33 @@ def close(self):
super().close()

def _create_read_future(self, size):
return self._loop.sock_recv(self._sock, size)
if self._alloc_read_buffers:
self._read_buffer = self._protocol.get_buffer(size)
return self._loop.sock_recv_into(self._sock, self._read_buffer)
else:
return self._loop.sock_recv(self._sock, size)

def _submit_read_data(self, data):
if data:
self._protocol.data_received(data)
if data != b"" and data != 0:
if self._alloc_read_buffers:
assert isinstance(data, int) # Actually `nbytes`
self._protocol.buffer_updated(data)
self._read_buffer = None
else:
assert isinstance(data, bytes)
self._protocol.data_received(data)
else:
self._read_buffer = None
keep_open = self._protocol.eof_received()
if not keep_open:
self.close()

def _loop_reading(self, fut=None):
if self._paused:
return
data = None

try:
data = None
if fut is not None:
assert self._read_fut is fut or (
self._read_fut is None and self._closing
Expand All @@ -150,7 +171,10 @@ def _loop_reading(self, fut=None):
data = None
return

if data == b"":
if data is not None:
self._submit_read_data(data)

if data == b"" or data == 0:
# No need to reschedule on end-of-file
return

Expand All @@ -172,9 +196,6 @@ def _loop_reading(self, fut=None):
self._cancelable.add(self._read_fut)
else:
self._read_fut.add_done_callback(self._loop_reading)
finally:
if data is not None:
self._submit_read_data(data)


class WriteTransport(BaseTransport, transports._FlowControlMixin):
Expand All @@ -184,8 +205,8 @@ def __init__(self, loop, *args, **kwargs):
transports._FlowControlMixin.__init__(self, None, loop)
BaseTransport.__init__(self, loop, *args, **kwargs)

self._buffer = self._buffer_factory()
self._buffer_empty_callbacks = set()
self._write_buffer = self._buffer_factory()
self._drained_callbacks = set()
self._write_fut = None
self._eof_written = False

Expand All @@ -196,7 +217,7 @@ def can_write_eof(self):
return True

def get_write_buffer_size(self):
return len(self._buffer)
return len(self._write_buffer)

def _close_write(self):
if self._write_fut is not None:
Expand All @@ -206,7 +227,7 @@ def transport_write_done_callback():
self._closing_delayed = False
self.close()

self._buffer_empty_callbacks.add(transport_write_done_callback)
self._drained_callbacks.add(transport_write_done_callback)

def close(self):
self._close_write()
Expand All @@ -231,12 +252,12 @@ def _create_write_future(self, data):
return self._loop.sock_sendall(self._sock, data)

def _buffer_add_data(self, data):
self._buffer.extend(data)
self._write_buffer.extend(data)

def _buffer_pop_data(self):
if len(self._buffer) > 0:
data = self._buffer
self._buffer = bytearray()
if len(self._write_buffer) > 0:
data = self._write_buffer
self._write_buffer = self._buffer_factory()
return data
else:
return None
Expand All @@ -257,10 +278,10 @@ def _loop_writing(self, fut=None, data=None):
data = self._buffer_pop_data()

if not data:
if len(self._buffer_empty_callbacks) > 0:
for callback in self._buffer_empty_callbacks:
if len(self._drained_callbacks) > 0:
for callback in self._drained_callbacks:
callback()
self._buffer_empty_callbacks.clear()
self._drained_callbacks.clear()

self._maybe_resume_protocol()
else:
Expand Down Expand Up @@ -351,11 +372,11 @@ def _create_write_future(self, args):
def _buffer_add_data(self, args):
(data, addr) = args

self._buffer.append((bytes(data), addr))
self._write_buffer.append((bytes(data), addr))

def _buffer_pop_data(self):
if len(self._buffer) > 0:
return self._buffer.popleft()
if len(self._write_buffer) > 0:
return self._write_buffer.popleft()
else:
return None

Expand Down Expand Up @@ -385,8 +406,29 @@ def __init__(self, loop, channel, protocol, waiter, extra):
super().__init__(loop, None, protocol, waiter, extra)

def _create_read_future(self, size):
if self._alloc_read_buffers:
self._read_buffer = self._protocol.get_buffer(size)
size = len(self._read_buffer)
return self._loop._channel_read(self._channel, size)

def _submit_read_data(self, data):
assert isinstance(data, bytes)
if data != b"" and data != 0:
if self._alloc_read_buffers:
# FIXME: GLib does not actually expose the equivalent to
# `recv_into` in its channel interface, so we have to
# add an extra copy here rather than avoiding one
self._read_buffer[0 : len(data)] = data
self._protocol.buffer_updated(len(data))
self._read_buffer = None
else:
self._protocol.data_received(data)
else:
self._read_buffer = None
keep_open = self._protocol.eof_received()
if not keep_open:
self.close()

def _force_close_async(self, exc):
try:
super()._force_close_async(exc)
Expand Down