diff --git a/httpx/__init__.py b/httpx/__init__.py index ebacaad87e..8b2dda0fca 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -5,8 +5,7 @@ from .concurrency.base import ( BaseBackgroundManager, BasePoolSemaphore, - BaseReader, - BaseWriter, + BaseStream, ConcurrencyBackend, ) from .config import ( @@ -105,8 +104,7 @@ "TooManyRedirects", "WriteTimeout", "AsyncDispatcher", - "BaseReader", - "BaseWriter", + "BaseStream", "ConcurrencyBackend", "Dispatcher", "URL", diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio.py index 8fa19e9fdc..34d68c3830 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio.py @@ -1,5 +1,5 @@ """ -The `Reader` and `Writer` classes here provide a lightweight layer over +The `Stream` class here provides a lightweight layer over `asyncio.StreamReader` and `asyncio.StreamWriter`. Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`. @@ -14,18 +14,17 @@ import typing from types import TracebackType +from ..config import PoolLimits, TimeoutConfig +from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .base import ( BaseBackgroundManager, BasePoolSemaphore, BaseEvent, BaseQueue, - BaseReader, - BaseWriter, + BaseStream, ConcurrencyBackend, TimeoutFlag, ) -from ..config import PoolLimits, TimeoutConfig -from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout SSL_MONKEY_PATCH_APPLIED = False @@ -51,13 +50,29 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore MonkeyPatch.write = _fixed_write -class Reader(BaseReader): +class Stream(BaseStream): def __init__( - self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig - ) -> None: + self, + stream_reader: asyncio.StreamReader, + stream_writer: asyncio.StreamWriter, + timeout: TimeoutConfig, + ): self.stream_reader = stream_reader + self.stream_writer = stream_writer self.timeout = timeout + def get_http_version(self) -> str: + ssl_object = self.stream_writer.get_extra_info("ssl_object") + + if ssl_object is None: + return "HTTP/1.1" + + ident = ssl_object.selected_alpn_protocol() + if ident is None: + ident = ssl_object.selected_npn_protocol() + + return "HTTP/2" if ident == "h2" else "HTTP/1.1" + async def read( self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None ) -> bytes: @@ -78,15 +93,6 @@ async def read( return data - def is_connection_dropped(self) -> bool: - return self.stream_reader.at_eof() - - -class Writer(BaseWriter): - def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig): - self.stream_writer = stream_writer - self.timeout = timeout - def write_no_block(self, data: bytes) -> None: self.stream_writer.write(data) # pragma: nocover @@ -114,6 +120,9 @@ async def write( if should_raise: raise WriteTimeout() from None + def is_connection_dropped(self) -> bool: + return self.stream_reader.at_eof() + async def close(self) -> None: self.stream_writer.close() @@ -172,7 +181,7 @@ async def connect( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseReader, BaseWriter, str]: + ) -> BaseStream: try: stream_reader, stream_writer = await asyncio.wait_for( # type: ignore asyncio.open_connection(hostname, port, ssl=ssl_context), @@ -181,19 +190,9 @@ async def connect( except asyncio.TimeoutError: raise ConnectTimeout() - ssl_object = stream_writer.get_extra_info("ssl_object") - if ssl_object is None: - ident = "http/1.1" - else: - ident = ssl_object.selected_alpn_protocol() - if ident is None: - ident = ssl_object.selected_npn_protocol() - - reader = Reader(stream_reader=stream_reader, timeout=timeout) - writer = Writer(stream_writer=stream_writer, timeout=timeout) - http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1" - - return reader, writer, http_version + return Stream( + stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout + ) async def run_in_threadpool( self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any diff --git a/httpx/concurrency/base.py b/httpx/concurrency/base.py index 077961d207..45785df1c3 100644 --- a/httpx/concurrency/base.py +++ b/httpx/concurrency/base.py @@ -37,29 +37,21 @@ def set_write_timeouts(self) -> None: self.raise_on_write_timeout = True -class BaseReader: +class BaseStream: """ - A stream reader. Abstracts away any asyncio-specific interfaces - into a more generic base class, that we can use with alternate - backend, or for stand-alone test cases. + A stream with read/write operations. Abstracts away any asyncio-specific + interfaces into a more generic base class, that we can use with alternate + backends, or for stand-alone test cases. """ + def get_http_version(self) -> str: + raise NotImplementedError() # pragma: no cover + async def read( self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None ) -> bytes: raise NotImplementedError() # pragma: no cover - def is_connection_dropped(self) -> bool: - raise NotImplementedError() # pragma: no cover - - -class BaseWriter: - """ - A stream writer. Abstracts away any asyncio-specific interfaces - into a more generic base class, that we can use with alternate - backend, or for stand-alone test cases. - """ - def write_no_block(self, data: bytes) -> None: raise NotImplementedError() # pragma: no cover @@ -69,6 +61,9 @@ async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None: async def close(self) -> None: raise NotImplementedError() # pragma: no cover + def is_connection_dropped(self) -> bool: + raise NotImplementedError() # pragma: no cover + class BaseQueue: """ @@ -118,7 +113,7 @@ async def connect( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseReader, BaseWriter, str]: + ) -> BaseStream: raise NotImplementedError() # pragma: no cover def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index 36f96f8065..0e9819cb98 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -79,24 +79,24 @@ async def connect( else: on_release = functools.partial(self.release_func, self) - reader, writer, http_version = await self.backend.connect( - host, port, ssl_context, timeout - ) + stream = await self.backend.connect(host, port, ssl_context, timeout) + http_version = stream.get_http_version() + if http_version == "HTTP/2": self.h2_connection = HTTP2Connection( - reader, writer, self.backend, on_release=on_release + stream, self.backend, on_release=on_release ) else: assert http_version == "HTTP/1.1" self.h11_connection = HTTP11Connection( - reader, writer, self.backend, on_release=on_release + stream, self.backend, on_release=on_release ) async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]: if not self.origin.is_ssl: return None - # Run the SSL loading in a threadpool, since it may makes disk accesses. + # Run the SSL loading in a threadpool, since it may make disk accesses. return await self.backend.run_in_threadpool( ssl.load_ssl_context, self.http_versions ) diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index 554591f804..236c81e2bb 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -2,7 +2,7 @@ import h11 -from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag +from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..models import AsyncRequest, AsyncResponse @@ -27,13 +27,11 @@ class HTTP11Connection: def __init__( self, - reader: BaseReader, - writer: BaseWriter, + stream: BaseStream, backend: ConcurrencyBackend, on_release: typing.Optional[OnReleaseCallback] = None, ): - self.reader = reader - self.writer = writer + self.stream = stream self.backend = backend self.on_release = on_release self.h11_state = h11.Connection(our_role=h11.CLIENT) @@ -67,7 +65,7 @@ async def close(self) -> None: except h11.LocalProtocolError: # pragma: no cover # Premature client disconnect pass - await self.writer.close() + await self.stream.close() async def _send_request( self, request: AsyncRequest, timeout: TimeoutConfig = None @@ -111,7 +109,7 @@ async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> N drain before returning. """ bytes_to_send = self.h11_state.send(event) - await self.writer.write(bytes_to_send, timeout) + await self.stream.write(bytes_to_send, timeout) async def _receive_response( self, timeout: TimeoutConfig = None @@ -154,7 +152,7 @@ async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event: event = self.h11_state.next_event() if event is h11.NEED_DATA: try: - data = await self.reader.read( + data = await self.stream.read( self.READ_NUM_BYTES, timeout, flag=self.timeout_flag ) except OSError: # pragma: nocover @@ -184,4 +182,4 @@ def is_closed(self) -> bool: return self.h11_state.our_state in (h11.CLOSED, h11.ERROR) def is_connection_dropped(self) -> bool: - return self.reader.is_connection_dropped() + return self.stream.is_connection_dropped() diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index bf258e3a7a..0a698f35f4 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -4,7 +4,7 @@ import h2.connection import h2.events -from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag +from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..models import AsyncRequest, AsyncResponse @@ -14,13 +14,11 @@ class HTTP2Connection: def __init__( self, - reader: BaseReader, - writer: BaseWriter, + stream: BaseStream, backend: ConcurrencyBackend, on_release: typing.Callable = None, ): - self.reader = reader - self.writer = writer + self.stream = stream self.backend = backend self.on_release = on_release self.h2_state = h2.connection.H2Connection() @@ -58,12 +56,12 @@ async def send( ) async def close(self) -> None: - await self.writer.close() + await self.stream.close() def initiate_connection(self) -> None: self.h2_state.initiate_connection() data_to_send = self.h2_state.data_to_send() - self.writer.write_no_block(data_to_send) + self.stream.write_no_block(data_to_send) self.initialized = True async def send_headers( @@ -78,7 +76,7 @@ async def send_headers( ] + [(k, v) for k, v in request.headers.raw if k != b"host"] self.h2_state.send_headers(stream_id, headers) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) return stream_id async def send_request_data( @@ -104,12 +102,12 @@ async def send_data( chunk = data[idx : idx + chunk_size] self.h2_state.send_data(stream_id, chunk) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None: self.h2_state.end_stream(stream_id) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) async def receive_response( self, stream_id: int, timeout: TimeoutConfig = None @@ -150,14 +148,14 @@ async def receive_event( ) -> h2.events.Event: while not self.events[stream_id]: flag = self.timeout_flags[stream_id] - data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag) + data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag) events = self.h2_state.receive_data(data) for event in events: if getattr(event, "stream_id", 0): self.events[event.stream_id].append(event) data_to_send = self.h2_state.data_to_send() - await self.writer.write(data_to_send, timeout) + await self.stream.write(data_to_send, timeout) return self.events[stream_id].pop(0) @@ -173,4 +171,4 @@ def is_closed(self) -> bool: return False def is_connection_dropped(self) -> bool: - return self.reader.is_connection_dropped() + return self.stream.is_connection_dropped() diff --git a/tests/dispatch/utils.py b/tests/dispatch/utils.py index b5aac85037..3315135797 100644 --- a/tests/dispatch/utils.py +++ b/tests/dispatch/utils.py @@ -6,7 +6,7 @@ import h2.connection import h2.events -from httpx import AsyncioBackend, BaseReader, BaseWriter, Request, TimeoutConfig +from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig class MockHTTP2Backend(AsyncioBackend): @@ -20,16 +20,12 @@ async def connect( port: int, ssl_context: typing.Optional[ssl.SSLContext], timeout: TimeoutConfig, - ) -> typing.Tuple[BaseReader, BaseWriter, str]: + ) -> BaseStream: self.server = MockHTTP2Server(self.app) - return self.server, self.server, "HTTP/2" + return self.server -class MockHTTP2Server(BaseReader, BaseWriter): - """ - This class exposes Reader and Writer style interfaces. - """ - +class MockHTTP2Server(BaseStream): def __init__(self, app): config = h2.config.H2Configuration(client_side=False) self.conn = h2.connection.H2Connection(config=config) @@ -38,15 +34,16 @@ def __init__(self, app): self.requests = {} self.close_connection = False - # BaseReader interface + # Stream interface + + def get_http_version(self) -> str: + return "HTTP/2" async def read(self, n, timeout, flag=None) -> bytes: await asyncio.sleep(0) send, self.buffer = self.buffer[:n], self.buffer[n:] return send - # BaseWriter interface - def write_no_block(self, data: bytes) -> None: events = self.conn.receive_data(data) self.buffer += self.conn.data_to_send()