From eaee2372cfd06849d5036097f34422b30a82fca8 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Tue, 20 Aug 2019 22:53:47 +0200 Subject: [PATCH] Unify BaseReader and BaseWriter as BaseStream --- httpx/__init__.py | 6 ++---- httpx/concurrency/asyncio.py | 40 +++++++++++++++++------------------- httpx/concurrency/base.py | 24 ++++++++-------------- httpx/dispatch/connection.py | 6 +++--- httpx/dispatch/http11.py | 16 +++++++-------- httpx/dispatch/http2.py | 24 ++++++++++------------ tests/dispatch/utils.py | 12 +++++------ 7 files changed, 55 insertions(+), 73 deletions(-) diff --git a/httpx/__init__.py b/httpx/__init__.py index ebacaad87e..8b2dda0fca 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -5,8 +5,7 @@ from .concurrency.base import ( BaseBackgroundManager, BasePoolSemaphore, - BaseReader, - BaseWriter, + BaseStream, ConcurrencyBackend, ) from .config import ( @@ -105,8 +104,7 @@ "TooManyRedirects", "WriteTimeout", "AsyncDispatcher", - "BaseReader", - "BaseWriter", + "BaseStream", "ConcurrencyBackend", "Dispatcher", "URL", diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index f378e4d940..883b08fa61 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -1,5 +1,5 @@ """ -The `Reader` and `Writer` classes here provide a lightweight layer over +The `Stream` class here provides a lightweight layer over `asyncio.StreamReader` and `asyncio.StreamWriter`. Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`. @@ -14,17 +14,16 @@ import typing from types import TracebackType +from ..config import PoolLimits, TimeoutConfig +from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .base import ( BaseBackgroundManager, BasePoolSemaphore, BaseQueue, - BaseReader, - BaseWriter, + BaseStream, ConcurrencyBackend, TimeoutFlag, ) -from ..config import PoolLimits, TimeoutConfig -from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout SSL_MONKEY_PATCH_APPLIED = False @@ -50,11 +49,15 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore MonkeyPatch.write = _fixed_write -class Reader(BaseReader): +class Stream(BaseStream): def __init__( - self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig - ) -> None: + self, + stream_reader: asyncio.StreamReader, + stream_writer: asyncio.StreamWriter, + timeout: TimeoutConfig, + ): self.stream_reader = stream_reader + self.stream_writer = stream_writer self.timeout = timeout async def read( @@ -77,15 +80,6 @@ async def read( return data - def is_connection_dropped(self) -> bool: - return self.stream_reader.at_eof() - - -class Writer(BaseWriter): - def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig): - self.stream_writer = stream_writer - self.timeout = timeout - def write_no_block(self, data: bytes) -> None: self.stream_writer.write(data) # pragma: nocover @@ -113,6 +107,9 @@ async def write( if should_raise: raise WriteTimeout() from None + def is_connection_dropped(self) -> bool: + return self.stream_reader.at_eof() + async def close(self) -> None: self.stream_writer.close() @@ -171,7 +168,7 @@ async def connect( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseReader, BaseWriter, str]: + ) -> typing.Tuple[BaseStream, str]: try: stream_reader, stream_writer = await asyncio.wait_for( # type: ignore asyncio.open_connection(hostname, port, ssl=ssl_context), @@ -188,11 +185,12 @@ async def connect( if ident is None: ident = ssl_object.selected_npn_protocol() - reader = Reader(stream_reader=stream_reader, timeout=timeout) - writer = Writer(stream_writer=stream_writer, timeout=timeout) + stream = Stream( + stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout + ) http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1" - return reader, writer, http_version + return stream, http_version async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 1a6842b852..20b511864c 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -37,11 +37,11 @@ def set_write_timeouts(self) -> None: self.raise_on_write_timeout = True -class BaseReader: +class BaseStream: """ - A stream reader. Abstracts away any asyncio-specific interfaces - into a more generic base class, that we can use with alternate - backend, or for stand-alone test cases. + A stream with read/write operations. Abstracts away any asyncio-specific + interfaces into a more generic base class, that we can use with alternate + backends, or for stand-alone test cases. """ async def read( @@ -49,17 +49,6 @@ async def read( ) -> bytes: raise NotImplementedError() # pragma: no cover - def is_connection_dropped(self) -> bool: - raise NotImplementedError() # pragma: no cover - - -class BaseWriter: - """ - A stream writer. Abstracts away any asyncio-specific interfaces - into a more generic base class, that we can use with alternate - backend, or for stand-alone test cases. - """ - def write_no_block(self, data: bytes) -> None: raise NotImplementedError() # pragma: no cover @@ -69,6 +58,9 @@ async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None: async def close(self) -> None: raise NotImplementedError() # pragma: no cover + def is_connection_dropped(self) -> bool: + raise NotImplementedError() # pragma: no cover + class BaseQueue: """ @@ -103,7 +95,7 @@ async def connect( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseReader, BaseWriter, str]: + ) -> typing.Tuple[BaseStream, str]: raise NotImplementedError() # pragma: no cover def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 36f96f8065..7f0d14eeb2 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -79,17 +79,17 @@ async def connect( else: on_release = functools.partial(self.release_func, self) - reader, writer, http_version = await self.backend.connect( + stream, http_version = await self.backend.connect( host, port, ssl_context, timeout ) if http_version == "HTTP/2": self.h2_connection = HTTP2Connection( - reader, writer, self.backend, on_release=on_release + stream, self.backend, on_release=on_release ) else: assert http_version == "HTTP/1.1" self.h11_connection = HTTP11Connection( - reader, writer, self.backend, on_release=on_release + stream, self.backend, on_release=on_release ) async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]: diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index 554591f804..236c81e2bb 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -2,7 +2,7 @@ import h11 -from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag +from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..models import AsyncRequest, AsyncResponse @@ -27,13 +27,11 @@ class HTTP11Connection: def __init__( self, - reader: BaseReader, - writer: BaseWriter, + stream: BaseStream, backend: ConcurrencyBackend, on_release: typing.Optional[OnReleaseCallback] = None, ): - self.reader = reader - self.writer = writer + self.stream = stream self.backend = backend self.on_release = on_release self.h11_state = h11.Connection(our_role=h11.CLIENT) @@ -67,7 +65,7 @@ async def close(self) -> None: except h11.LocalProtocolError: # pragma: no cover # Premature client disconnect pass - await self.writer.close() + await self.stream.close() async def _send_request( self, request: AsyncRequest, timeout: TimeoutConfig = None @@ -111,7 +109,7 @@ async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> N drain before returning. """ bytes_to_send = self.h11_state.send(event) - await self.writer.write(bytes_to_send, timeout) + await self.stream.write(bytes_to_send, timeout) async def _receive_response( self, timeout: TimeoutConfig = None @@ -154,7 +152,7 @@ async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event: event = self.h11_state.next_event() if event is h11.NEED_DATA: try: - data = await self.reader.read( + data = await self.stream.read( self.READ_NUM_BYTES, timeout, flag=self.timeout_flag ) except OSError: # pragma: nocover @@ -184,4 +182,4 @@ def is_closed(self) -> bool: return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) def is_connection_dropped(self) -> bool: - return self.reader.is_connection_dropped() + return self.stream.is_connection_dropped() diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index bf258e3a7a..0a698f35f4 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -4,7 +4,7 @@ import h2.connection import h2.events -from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag +from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..models import AsyncRequest, AsyncResponse @@ -14,13 +14,11 @@ class HTTP2Connection: def __init__( self, - reader: BaseReader, - writer: BaseWriter, + stream: BaseStream, backend: ConcurrencyBackend, on_release: typing.Callable = None, ): - self.reader = reader - self.writer = writer + self.stream = stream self.backend = backend self.on_release = on_release self.h2_state = h2.connection.H2Connection() @@ -58,12 +56,12 @@ async def send( ) async def close(self) -> None: - await self.writer.close() + await self.stream.close() def initiate_connection(self) -> None: self.h2_state.initiate_connection() data_to_send = self.h2_state.data_to_send() - self.writer.write_no_block(data_to_send) + self.stream.write_no_block(data_to_send) self.initialized = True async def send_headers( @@ -78,7 +76,7 @@ async def send_headers( ] + [(k, v) for k, v in request.headers.raw if k != b"host"] self.h2_state.send_headers(stream_id, headers) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) return stream_id async def send_request_data( @@ -104,12 +102,12 @@ async def send_data( chunk = data[idx : idx + chunk_size] self.h2_state.send_data(stream_id, chunk) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None: self.h2_state.end_stream(stream_id) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) async def receive_response( self, stream_id: int, timeout: TimeoutConfig = None @@ -150,14 +148,14 @@ async def receive_event( ) -> h2.events.Event: while not self.events[stream_id]: flag = self.timeout_flags[stream_id] - data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag) + data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag) events = self.h2_state.receive_data(data) for event in events: if getattr(event, "stream_id", 0): self.events[event.stream_id].append(event) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) return self.events[stream_id].pop(0) @@ -173,4 +171,4 @@ def is_closed(self) -> bool: return False def is_connection_dropped(self) -> bool: - return self.reader.is_connection_dropped() + return self.stream.is_connection_dropped() diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index b5aac85037..8198475d71 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -6,7 +6,7 @@ import h2.connection import h2.events -from httpx import AsyncioBackend, BaseReader, BaseWriter, Request, TimeoutConfig +from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig class MockHTTP2Backend(AsyncioBackend): @@ -20,12 +20,12 @@ async def connect( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseReader, BaseWriter, str]: + ) -> typing.Tuple[BaseStream, str]: self.server = MockHTTP2Server(self.app) - return self.server, self.server, "HTTP/2" + return self.server, "HTTP/2" -class MockHTTP2Server(BaseReader, BaseWriter): +class MockHTTP2Server(BaseStream): """ This class exposes Reader and Writer style interfaces. """ @@ -38,15 +38,13 @@ def __init__(self, app): self.requests = {} self.close_connection = False - # BaseReader interface + # Stream interface async def read(self, n, timeout, flag=None) -> bytes: await asyncio.sleep(0) send, self.buffer = self.buffer[:n], self.buffer[n:] return send - # BaseWriter interface - def write_no_block(self, data: bytes) -> None: events = self.conn.receive_data(data) self.buffer += self.conn.data_to_send()