Skip to content

Commit

Permalink
Encapsulate http_version into BaseStream
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Aug 21, 2019
1 parent 14eb81f commit fdb1359
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 20 deletions.
27 changes: 14 additions & 13 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def __init__(
self.stream_writer = stream_writer
self.timeout = timeout

def get_http_version(self) -> str:
ssl_object = self.stream_writer.get_extra_info("ssl_object")

if ssl_object is None:
return "HTTP/1.1"

ident = ssl_object.selected_alpn_protocol()
if ident is None:
ident = ssl_object.selected_npn_protocol()

return "HTTP/2" if ident == "h2" else "HTTP/1.1"

async def read(
self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
) -> bytes:
Expand Down Expand Up @@ -169,7 +181,7 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseStream, str]:
) -> BaseStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
Expand All @@ -178,20 +190,9 @@ async def connect(
except asyncio.TimeoutError:
raise ConnectTimeout()

ssl_object = stream_writer.get_extra_info("ssl_object")
if ssl_object is None:
ident = "http/1.1"
else:
ident = ssl_object.selected_alpn_protocol()
if ident is None:
ident = ssl_object.selected_npn_protocol()

stream = Stream(
return Stream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
http_version = "HTTP/2" if ident == "h2" else "HTTP/1.1"

return stream, http_version

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
Expand Down
5 changes: 4 additions & 1 deletion httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class BaseStream:
backends, or for stand-alone test cases.
"""

def get_http_version(self) -> str:
raise NotImplementedError() # pragma: no cover

async def read(
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
) -> bytes:
Expand Down Expand Up @@ -110,7 +113,7 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseStream, str]:
) -> BaseStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
Expand Down
8 changes: 4 additions & 4 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ async def connect(
else:
on_release = functools.partial(self.release_func, self)

stream, http_version = await self.backend.connect(
host, port, ssl_context, timeout
)
stream = await self.backend.connect(host, port, ssl_context, timeout)
http_version = stream.get_http_version()

if http_version == "HTTP/2":
self.h2_connection = HTTP2Connection(
stream, self.backend, on_release=on_release
Expand All @@ -96,7 +96,7 @@ async def get_ssl_context(self, ssl: SSLConfig) -> typing.Optional[ssl.SSLContex
if not self.origin.is_ssl:
return None

# Run the SSL loading in a threadpool, since it may makes disk accesses.
# Run the SSL loading in a threadpool, since it may make disk accesses.
return await self.backend.run_in_threadpool(
ssl.load_ssl_context, self.http_versions
)
Expand Down
7 changes: 5 additions & 2 deletions tests/dispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ async def connect(
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> typing.Tuple[BaseStream, str]:
) -> BaseStream:
self.server = MockHTTP2Server(self.app)
return self.server, "HTTP/2"
return self.server


class MockHTTP2Server(BaseStream):
Expand All @@ -36,6 +36,9 @@ def __init__(self, app):

# Stream interface

def get_http_version(self) -> str:
return "HTTP/2"

async def read(self, n, timeout, flag=None) -> bytes:
await asyncio.sleep(0)
send, self.buffer = self.buffer[:n], self.buffer[n:]
Expand Down

0 comments on commit fdb1359

Please sign in to comment.