Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify BaseReader and BaseWriter as BaseStream #255

Merged
merged 2 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions httpx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from .concurrency.base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseReader,
BaseWriter,
BaseStream,
ConcurrencyBackend,
)
from .config import (
Expand Down Expand Up @@ -105,8 +104,7 @@
"TooManyRedirects",
"WriteTimeout",
"AsyncDispatcher",
"BaseReader",
"BaseWriter",
"BaseStream",
"ConcurrencyBackend",
"Dispatcher",
"URL",
Expand Down
61 changes: 30 additions & 31 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
The `Reader` and `Writer` classes here provide a lightweight layer over
The `Stream` class here provides a lightweight layer over
`asyncio.StreamReader` and `asyncio.StreamWriter`.

Similarly `PoolSemaphore` is a lightweight layer over `BoundedSemaphore`.
Expand All @@ -14,18 +14,17 @@
import typing
from types import TracebackType

from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BasePoolSemaphore,
BaseEvent,
BaseQueue,
BaseReader,
BaseWriter,
BaseStream,
ConcurrencyBackend,
TimeoutFlag,
)
from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout

SSL_MONKEY_PATCH_APPLIED = False

Expand All @@ -51,13 +50,29 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore
MonkeyPatch.write = _fixed_write


class Reader(BaseReader):
class Stream(BaseStream):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
) -> None:
self,
stream_reader: asyncio.StreamReader,
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
stream_writer: asyncio.StreamWriter,
timeout: TimeoutConfig,
):
self.stream_reader = stream_reader
self.stream_writer = stream_writer
self.timeout = timeout

def get_http_version(self) -> str:
ssl_object = self.stream_writer.get_extra_info("ssl_object")

if ssl_object is None:
return "HTTP/1.1"

ident = ssl_object.selected_alpn_protocol()
if ident is None:
ident = ssl_object.selected_npn_protocol()

return "HTTP/2" if ident == "h2" else "HTTP/1.1"

async def read(
self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
) -> bytes:
Expand All @@ -78,15 +93,6 @@ async def read(

return data

def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()


class Writer(BaseWriter):
def __init__(self, stream_writer: asyncio.StreamWriter, timeout: TimeoutConfig):
self.stream_writer = stream_writer
self.timeout = timeout

def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data) # pragma: nocover

Expand Down Expand Up @@ -114,6 +120,9 @@ async def write(
if should_raise:
raise WriteTimeout() from None

def is_connection_dropped(self) -> bool:
return self.stream_reader.at_eof()

async def close(self) -> None:
self.stream_writer.close()

Expand Down Expand Up @@ -172,7 +181,7 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> BaseStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
Expand All @@ -181,19 +190,9 @@ async def connect(
except asyncio.TimeoutError:
raise ConnectTimeout()

ssl_object = stream_writer.get_extra_info("ssl_object")
if ssl_object is None:
ident = "http/1.1"
else:
ident = ssl_object.selected_alpn_protocol()
if ident is None:
ident = ssl_object.selected_npn_protocol()

reader = Reader(stream_reader=stream_reader, timeout=timeout)
writer = Writer(stream_writer=stream_writer, timeout=timeout)
http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"

return reader, writer, http_version
return Stream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
Expand Down
27 changes: 11 additions & 16 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,21 @@ def set_write_timeouts(self) -> None:
self.raise_on_write_timeout = True


class BaseReader:
class BaseStream:
"""
A stream reader. Abstracts away any asyncio-specific interfaces
into a more generic base class, that we can use with alternate
backend, or for stand-alone test cases.
A stream with read/write operations. Abstracts away any asyncio-specific
interfaces into a more generic base class, that we can use with alternate
backends, or for stand-alone test cases.
"""

def get_http_version(self) -> str:
raise NotImplementedError() # pragma: no cover

async def read(
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
) -> bytes:
raise NotImplementedError() # pragma: no cover

def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover


class BaseWriter:
"""
A stream writer. Abstracts away any asyncio-specific interfaces
into a more generic base class, that we can use with alternate
backend, or for stand-alone test cases.
"""

def write_no_block(self, data: bytes) -> None:
raise NotImplementedError() # pragma: no cover

Expand All @@ -69,6 +61,9 @@ async def write(self, data: bytes, timeout: TimeoutConfig = None) -> None:
async def close(self) -> None:
raise NotImplementedError() # pragma: no cover

def is_connection_dropped(self) -> bool:
raise NotImplementedError() # pragma: no cover


class BaseQueue:
"""
Expand Down Expand Up @@ -118,7 +113,7 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseReader, BaseWriter, str]:
) -> BaseStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
Expand Down
12 changes: 6 additions & 6 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,24 @@ async def connect(
else:
on_release = functools.partial(self.release_func, self)

reader, writer, http_version = await self.backend.connect(
host, port, ssl_context, timeout
)
stream = await self.backend.connect(host, port, ssl_context, timeout)
http_version = stream.get_http_version()

if http_version == "HTTP/2":
self.h2_connection = HTTP2Connection(
reader, writer, self.backend, on_release=on_release
stream, self.backend, on_release=on_release
)
else:
assert http_version == "HTTP/1.1"
self.h11_connection = HTTP11Connection(
reader, writer, self.backend, on_release=on_release
stream, self.backend, on_release=on_release
)

async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContext]:
if not self.origin.is_ssl:
return None

# Run the SSL loading in a threadpool, since it may makes disk accesses.
# Run the SSL loading in a threadpool, since it may make disk accesses.
return await self.backend.run_in_threadpool(
ssl.load_ssl_context, self.http_versions
)
Expand Down
16 changes: 7 additions & 9 deletions httpx/dispatch/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import h11

from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse

Expand All @@ -27,13 +27,11 @@ class HTTP11Connection:

def __init__(
self,
reader: BaseReader,
writer: BaseWriter,
stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Optional[OnReleaseCallback] = None,
):
self.reader = reader
self.writer = writer
self.stream = stream
self.backend = backend
self.on_release = on_release
self.h11_state = h11.Connection(our_role=h11.CLIENT)
Expand Down Expand Up @@ -67,7 +65,7 @@ async def close(self) -> None:
except h11.LocalProtocolError: # pragma: no cover
# Premature client disconnect
pass
await self.writer.close()
await self.stream.close()

async def _send_request(
self, request: AsyncRequest, timeout: TimeoutConfig = None
Expand Down Expand Up @@ -111,7 +109,7 @@ async def _send_event(self, event: H11Event, timeout: TimeoutConfig = None) -> N
drain before returning.
"""
bytes_to_send = self.h11_state.send(event)
await self.writer.write(bytes_to_send, timeout)
await self.stream.write(bytes_to_send, timeout)

async def _receive_response(
self, timeout: TimeoutConfig = None
Expand Down Expand Up @@ -154,7 +152,7 @@ async def _receive_event(self, timeout: TimeoutConfig = None) -> H11Event:
event = self.h11_state.next_event()
if event is h11.NEED_DATA:
try:
data = await self.reader.read(
data = await self.stream.read(
self.READ_NUM_BYTES, timeout, flag=self.timeout_flag
)
except OSError: # pragma: nocover
Expand Down Expand Up @@ -184,4 +182,4 @@ def is_closed(self) -> bool:
return self.h11_state.our_state in (h11.CLOSED, h11.ERROR)

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
return self.stream.is_connection_dropped()
24 changes: 11 additions & 13 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import h2.connection
import h2.events

from ..concurrency.base import BaseReader, BaseWriter, ConcurrencyBackend, TimeoutFlag
from ..concurrency.base import BaseStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..models import AsyncRequest, AsyncResponse

Expand All @@ -14,13 +14,11 @@ class HTTP2Connection:

def __init__(
self,
reader: BaseReader,
writer: BaseWriter,
stream: BaseStream,
backend: ConcurrencyBackend,
on_release: typing.Callable = None,
):
self.reader = reader
self.writer = writer
self.stream = stream
self.backend = backend
self.on_release = on_release
self.h2_state = h2.connection.H2Connection()
Expand Down Expand Up @@ -58,12 +56,12 @@ async def send(
)

async def close(self) -> None:
await self.writer.close()
await self.stream.close()

def initiate_connection(self) -> None:
self.h2_state.initiate_connection()
data_to_send = self.h2_state.data_to_send()
self.writer.write_no_block(data_to_send)
self.stream.write_no_block(data_to_send)
self.initialized = True

async def send_headers(
Expand All @@ -78,7 +76,7 @@ async def send_headers(
] + [(k, v) for k, v in request.headers.raw if k != b"host"]
self.h2_state.send_headers(stream_id, headers)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)
return stream_id

async def send_request_data(
Expand All @@ -104,12 +102,12 @@ async def send_data(
chunk = data[idx : idx + chunk_size]
self.h2_state.send_data(stream_id, chunk)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)

async def end_stream(self, stream_id: int, timeout: TimeoutConfig = None) -> None:
self.h2_state.end_stream(stream_id)
data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)

async def receive_response(
self, stream_id: int, timeout: TimeoutConfig = None
Expand Down Expand Up @@ -150,14 +148,14 @@ async def receive_event(
) -> h2.events.Event:
while not self.events[stream_id]:
flag = self.timeout_flags[stream_id]
data = await self.reader.read(self.READ_NUM_BYTES, timeout, flag=flag)
data = await self.stream.read(self.READ_NUM_BYTES, timeout, flag=flag)
events = self.h2_state.receive_data(data)
for event in events:
if getattr(event, "stream_id", 0):
self.events[event.stream_id].append(event)

data_to_send = self.h2_state.data_to_send()
await self.writer.write(data_to_send, timeout)
await self.stream.write(data_to_send, timeout)

return self.events[stream_id].pop(0)

Expand All @@ -173,4 +171,4 @@ def is_closed(self) -> bool:
return False

def is_connection_dropped(self) -> bool:
return self.reader.is_connection_dropped()
return self.stream.is_connection_dropped()
Loading