Skip to content

Commit

Permalink
Unify BaseReader and BaseWriter as BaseStream
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Aug 21, 2019
1 parent c0554e9 commit eaee237
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 73 deletions.
6 changes: 2 additions & 4 deletions httpx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseReader,
BaseWriter,
BaseStream,
ConcurrencyBackend,
)
from .config import (
Expand Down Expand Up @@ -105,8 +104,7 @@
"TooManyRedirects",
"WriteTimeout",
"AsyncDispatcher",
"BaseReader",
"BaseWriter",
"BaseStream",
"ConcurrencyBackend",
"Dispatcher",
"URL",
Expand Down
40 changes: 19 additions & 21 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
The `Reader` and `Writer` classes here provide a lightweight layer over
The `Stream` class here provides a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.
Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
Expand All @@ -14,17 +14,16 @@
import typing
from types import TracebackType

from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseQueue,
BaseReader,
BaseWriter,
BaseStream,
ConcurrencyBackend,
TimeoutFlag,
)
from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout

SSL_MONKEY_PATCH_APPLIED = False

Expand All @@ -50,11 +49,15 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore
MonkeyPatch.write = _fixed_write


class Reader(BaseReader):
class Stream(BaseStream):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
) -> None:
self,
stream_reader: asyncio.StreamReader,
stream_writer: asyncio.StreamWriter,
timeout: TimeoutConfig,
):
self.stream_reader = stream_reader
self.stream_writer = stream_writer
self.timeout = timeout

async def read(
Expand All @@ -77,15 +80,6 @@ async def read(

return data

def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()


class Writer(BaseWriter):
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
self.stream_writer = stream_writer
self.timeout = timeout

def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data) # pragma: nocover

Expand Down Expand Up @@ -113,6 +107,9 @@ async def write(
if should_raise:
raise WriteTimeout() from None

def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()

async def close(self) -> None:
self.stream_writer.close()

Expand Down Expand Up @@ -171,7 +168,7 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> typing.Tuple[BaseStream, str]:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
Expand All @@ -188,11 +185,12 @@ async def connect(
if ident is None:
ident = ssl_object.selected_npn_protocol()

reader = Reader(stream_reader=stream_reader, timeout=timeout)
writer = Writer(stream_writer=stream_writer, timeout=timeout)
stream = Stream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"

return reader, writer, http_version
return stream, http_version

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
Expand Down
24 changes: 8 additions & 16 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,18 @@ def set_write_timeouts(self) -> None:
self.raise_on_write_timeout = True


class BaseReader:
class BaseStream:
"""
A stream reader. Abstracts away any asyncio-specific interfaces
into a more generic base class, that we can use with alternate
backend, or for stand-alone test cases.
A stream with read/write operations. Abstracts away any asyncio-specific
interfaces into a more generic base class, that we can use with alternate
backends, or for stand-alone test cases.
"""

async def read(
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
) -> bytes:
raise NotImplementedError() # pragma: no cover

def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover


class BaseWriter:
"""
A stream writer. Abstracts away any asyncio-specific interfaces
into a more generic base class, that we can use with alternate
backend, or for stand-alone test cases.
"""

def write_no_block(self, data: bytes) -> None:
raise NotImplementedError() # pragma: no cover

Expand All @@ -69,6 +58,9 @@ async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
async def close(self) -> None:
raise NotImplementedError() # pragma: no cover

def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover


class BaseQueue:
"""
Expand Down Expand Up @@ -103,7 +95,7 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> typing.Tuple[BaseStream, str]:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
Expand Down
6 changes: 3 additions & 3 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,17 @@ async def connect(
else:
on_release = functools.partial(self.release_func, self)

reader, writer, http_version = await self.backend.connect(
stream, http_version = await self.backend.connect(
host, port, ssl_context, timeout
)
if http_version == "HTTP/2":
self.h2_connection = HTTP2Connection(
reader, writer, self.backend, on_release=on_release
stream, self.backend, on_release=on_release
)
else:
assert http_version == "HTTP/1.1"
self.h11_connection = HTTP11Connection(
reader, writer, self.backend, on_release=on_release
stream, self.backend, on_release=on_release
)

async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
Expand Down
16 changes: 7 additions & 9 deletions httpx/dispatch/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import h11

from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse

Expand All @@ -27,13 +27,11 @@ class HTTP11Connection:

def __init__(
self,
reader: BaseReader,
writer: BaseWriter,
stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.reader = reader
self.writer = writer
self.stream = stream
self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
Expand Down Expand Up @@ -67,7 +65,7 @@ async def close(self) -> None:
except h11.LocalProtocolError: # pragma: no cover
# Premature client disconnect
pass
await self.writer.close()
await self.stream.close()

async def _send_request(
self, request: AsyncRequest, timeout: TimeoutConfig = None
Expand Down Expand Up @@ -111,7 +109,7 @@ async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> N
drain before returning.
"""
bytes_to_send = self.h11_state.send(event)
await self.writer.write(bytes_to_send, timeout)
await self.stream.write(bytes_to_send, timeout)

async def _receive_response(
self, timeout: TimeoutConfig = None
Expand Down Expand Up @@ -154,7 +152,7 @@ async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
event = self.h11_state.next_event()
if event is h11.NEED_DATA:
try:
data = await self.reader.read(
data = await self.stream.read(
self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
)
except OSError: # pragma: nocover
Expand Down Expand Up @@ -184,4 +182,4 @@ def is_closed(self) -> bool:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
return self.stream.is_connection_dropped()
24 changes: 11 additions & 13 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import h2.connection
import h2.events

from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse

Expand All @@ -14,13 +14,11 @@ class HTTP2Connection:

def __init__(
self,
reader: BaseReader,
writer: BaseWriter,
stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Callable = None,
):
self.reader = reader
self.writer = writer
self.stream = stream
self.backend = backend
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
Expand Down Expand Up @@ -58,12 +56,12 @@ async def send(
)

async def close(self) -> None:
await self.writer.close()
await self.stream.close()

def initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
self.writer.write_no_block(data_to_send)
self.stream.write_no_block(data_to_send)
self.initialized = True

async def send_headers(
Expand All @@ -78,7 +76,7 @@ async def send_headers(
] + [(k, v) for k, v in request.headers.raw if k != b"host"]
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)
return stream_id

async def send_request_data(
Expand All @@ -104,12 +102,12 @@ async def send_data(
chunk = data[idx : idx + chunk_size]
self.h2_state.send_data(stream_id, chunk)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)

async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)

async def receive_response(
self, stream_id: int, timeout: TimeoutConfig = None
Expand Down Expand Up @@ -150,14 +148,14 @@ async def receive_event(
) -> h2.events.Event:
while not self.events[stream_id]:
flag = self.timeout_flags[stream_id]
data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag)
events = self.h2_state.receive_data(data)
for event in events:
if getattr(event, "stream_id", 0):
self.events[event.stream_id].append(event)

data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)

return self.events[stream_id].pop(0)

Expand All @@ -173,4 +171,4 @@ def is_closed(self) -> bool:
return False

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
return self.stream.is_connection_dropped()
12 changes: 5 additions & 7 deletions tests/dispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import h2.connection
import h2.events

from httpx import AsyncioBackend, BaseReader, BaseWriter, Request, TimeoutConfig
from httpx import AsyncioBackend, BaseStream, Request, TimeoutConfig


class MockHTTP2Backend(AsyncioBackend):
Expand All @@ -20,12 +20,12 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> typing.Tuple[BaseStream, str]:
self.server = MockHTTP2Server(self.app)
return self.server, self.server, "HTTP/2"
return self.server, "HTTP/2"


class MockHTTP2Server(BaseReader, BaseWriter):
class MockHTTP2Server(BaseStream):
"""
This class exposes Reader and Writer style interfaces.
"""
Expand All @@ -38,15 +38,13 @@ def __init__(self, app):
self.requests = {}
self.close_connection = False

# BaseReader interface
# Stream interface

async def read(self, n, timeout, flag=None) -> bytes:
await asyncio.sleep(0)
send, self.buffer = self.buffer[:n], self.buffer[n:]
return send

# BaseWriter interface

def write_no_block(self, data: bytes) -> None:
events = self.conn.receive_data(data)
self.buffer += self.conn.data_to_send()
Expand Down

0 comments on commit eaee237

Please sign in to comment.