diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 261cd28e12b..ffaaa653022 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -51,14 +51,7 @@ ) from .client_proto import ResponseHandler from .client_reqrep import SSL_ALLOWED_TYPES, ClientRequest, Fingerprint -from .helpers import ( - _SENTINEL, - ceil_timeout, - is_ip_address, - sentinel, - set_exception, - set_result, -) +from .helpers import _SENTINEL, ceil_timeout, is_ip_address, sentinel, set_result from .locks import EventResultOrError from .resolver import DefaultResolver @@ -729,6 +722,35 @@ def expired(self, key: Tuple[str, int]) -> bool: return self._timestamps[key] + self._ttl < monotonic() +def _make_ssl_context(verified: bool) -> SSLContext: + """Create SSL context. + + This method is not async-friendly and should be called from a thread + because it will load certificates from disk and do other blocking I/O. + """ + if ssl is None: + # No ssl support + return None + if verified: + return ssl.create_default_context() + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.check_hostname = False + sslcontext.verify_mode = ssl.CERT_NONE + sslcontext.options |= ssl.OP_NO_COMPRESSION + sslcontext.set_default_verify_paths() + return sslcontext + + +# The default SSLContext objects are created at import time +# since they do blocking I/O to load certificates from disk, +# and imports should always be done before the event loop starts +# or in a thread. +_SSL_CONTEXT_VERIFIED = _make_ssl_context(True) +_SSL_CONTEXT_UNVERIFIED = _make_ssl_context(False) + + class TCPConnector(BaseConnector): """TCP connector. @@ -759,7 +781,6 @@ class TCPConnector(BaseConnector): """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) - _made_ssl_context: Dict[bool, "asyncio.Future[SSLContext]"] = {} def __init__( self, @@ -963,25 +984,7 @@ async def _create_connection( return proto - @staticmethod - def _make_ssl_context(verified: bool) -> SSLContext: - """Create SSL context. - - This method is not async-friendly and should be called from a thread - because it will load certificates from disk and do other blocking I/O. - """ - if verified: - return ssl.create_default_context() - sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.check_hostname = False - sslcontext.verify_mode = ssl.CERT_NONE - sslcontext.options |= ssl.OP_NO_COMPRESSION - sslcontext.set_default_verify_paths() - return sslcontext - - async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -1005,35 +1008,14 @@ async def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: return sslcontext if sslcontext is not True: # not verified or fingerprinted - return await self._make_or_get_ssl_context(False) + return _SSL_CONTEXT_UNVERIFIED sslcontext = self._ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext if sslcontext is not True: # not verified or fingerprinted - return await self._make_or_get_ssl_context(False) - return await self._make_or_get_ssl_context(True) - - async def _make_or_get_ssl_context(self, verified: bool) -> SSLContext: - """Create or get cached SSL context.""" - try: - return await self._made_ssl_context[verified] - except KeyError: - loop = self._loop - future = loop.create_future() - self._made_ssl_context[verified] = future - try: - result = await loop.run_in_executor( - None, self._make_ssl_context, verified - ) - # BaseException is used since we might get CancelledError - except BaseException as ex: - del self._made_ssl_context[verified] - set_exception(future, ex) - raise - else: - set_result(future, result) - return result + return _SSL_CONTEXT_UNVERIFIED + return _SSL_CONTEXT_VERIFIED def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl @@ -1120,13 +1102,11 @@ async def _start_tls_connection( ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: """Wrap the raw TCP transport with TLS.""" tls_proto = self._factory() # Create a brand new proto for TLS - - # Safety of the `cast()` call here is based on the fact that - # internally `_get_ssl_context()` only returns `None` when - # `req.is_ssl()` evaluates to `False` which is never gonna happen - # in this code path. Of course, it's rather fragile - # maintainability-wise but this is to be solved separately. - sslcontext = cast(ssl.SSLContext, await self._get_ssl_context(req)) + sslcontext = self._get_ssl_context(req) + if TYPE_CHECKING: + # _start_tls_connection is unreachable in the current code path + # if sslcontext is None. + assert sslcontext is not None try: async with ceil_timeout( @@ -1204,7 +1184,7 @@ async def _create_direct_connection( *, client_error: Type[Exception] = ClientConnectorError, ) -> Tuple[asyncio.Transport, ResponseHandler]: - sslcontext = await self._get_ssl_context(req) + sslcontext = self._get_ssl_context(req) fingerprint = self._get_fingerprint(req) host = req.url.raw_host diff --git a/tests/test_connector.py b/tests/test_connector.py index d66bc214ed2..fe244be466a 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1,5 +1,4 @@ # Tests of http client with custom Connector - import asyncio import gc import hashlib @@ -9,6 +8,7 @@ import sys import uuid from collections import deque +from concurrent import futures from contextlib import closing, suppress from typing import ( Awaitable, @@ -16,11 +16,11 @@ Dict, Iterator, List, + Literal, NoReturn, Optional, Sequence, Tuple, - Type, ) from unittest import mock @@ -30,11 +30,23 @@ from yarl import URL import aiohttp -from aiohttp import ClientRequest, ClientSession, ClientTimeout, web +from aiohttp import ( + ClientRequest, + ClientSession, + ClientTimeout, + connector as connector_module, + web, +) from aiohttp.abc import ResolveResult from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ConnectionKey -from aiohttp.connector import Connection, TCPConnector, _DNSCacheTable +from aiohttp.connector import ( + _SSL_CONTEXT_UNVERIFIED, + _SSL_CONTEXT_VERIFIED, + Connection, + TCPConnector, + _DNSCacheTable, +) from aiohttp.locks import EventResultOrError from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer from aiohttp.test_utils import make_mocked_coro, unused_port @@ -1710,23 +1722,11 @@ async def test_tcp_connector_clear_dns_cache_bad_args( conn.clear_dns_cache("localhost") -async def test_dont_recreate_ssl_context() -> None: - conn = aiohttp.TCPConnector() - ctx = await conn._make_or_get_ssl_context(True) - assert ctx is await conn._make_or_get_ssl_context(True) - - -async def test_dont_recreate_ssl_context2() -> None: - conn = aiohttp.TCPConnector() - ctx = await conn._make_or_get_ssl_context(False) - assert ctx is await conn._make_or_get_ssl_context(False) - - async def test___get_ssl_context1() -> None: conn = aiohttp.TCPConnector() req = mock.Mock() req.is_ssl.return_value = False - assert await conn._get_ssl_context(req) is None + assert conn._get_ssl_context(req) is None async def test___get_ssl_context2() -> None: @@ -1735,7 +1735,7 @@ async def test___get_ssl_context2() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = ctx - assert await conn._get_ssl_context(req) is ctx + assert conn._get_ssl_context(req) is ctx async def test___get_ssl_context3() -> None: @@ -1744,7 +1744,7 @@ async def test___get_ssl_context3() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert await conn._get_ssl_context(req) is ctx + assert conn._get_ssl_context(req) is ctx async def test___get_ssl_context4() -> None: @@ -1753,9 +1753,7 @@ async def test___get_ssl_context4() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = False - assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( - False - ) + assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED async def test___get_ssl_context5() -> None: @@ -1764,9 +1762,7 @@ async def test___get_ssl_context5() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = aiohttp.Fingerprint(hashlib.sha256(b"1").digest()) - assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context( - False - ) + assert conn._get_ssl_context(req) is _SSL_CONTEXT_UNVERIFIED async def test___get_ssl_context6() -> None: @@ -1774,7 +1770,7 @@ async def test___get_ssl_context6() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert await conn._get_ssl_context(req) is await conn._make_or_get_ssl_context(True) + assert conn._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED async def test_ssl_context_once() -> None: @@ -1786,31 +1782,9 @@ async def test_ssl_context_once() -> None: req = mock.Mock() req.is_ssl.return_value = True req.ssl = True - assert await conn1._get_ssl_context(req) is await conn1._make_or_get_ssl_context( - True - ) - assert await conn2._get_ssl_context(req) is await conn1._make_or_get_ssl_context( - True - ) - assert await conn3._get_ssl_context(req) is await conn1._make_or_get_ssl_context( - True - ) - assert conn1._made_ssl_context is conn2._made_ssl_context is conn3._made_ssl_context - assert True in conn1._made_ssl_context - - -@pytest.mark.parametrize("exception", [OSError, ssl.SSLError, asyncio.CancelledError]) -async def test_ssl_context_creation_raises(exception: Type[BaseException]) -> None: - """Test that we try again if SSLContext creation fails the first time.""" - conn = aiohttp.TCPConnector() - conn._made_ssl_context.clear() - - with mock.patch.object( - conn, "_make_ssl_context", side_effect=exception - ), pytest.raises(exception): - await conn._make_or_get_ssl_context(True) - - assert isinstance(await conn._make_or_get_ssl_context(True), ssl.SSLContext) + assert conn1._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED + assert conn2._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED + assert conn3._get_ssl_context(req) is _SSL_CONTEXT_VERIFIED async def test_close_twice(loop: asyncio.AbstractEventLoop, key: ConnectionKey) -> None: @@ -2977,3 +2951,42 @@ async def allow_connection_and_add_dummy_waiter() -> None: ) await connector.close() + + +def test_connector_multiple_event_loop() -> None: + """Test the connector with multiple event loops.""" + + async def async_connect() -> Literal[True]: + conn = aiohttp.TCPConnector() + loop = asyncio.get_running_loop() + req = ClientRequest("GET", URL("https://127.0.0.1"), loop=loop) + with suppress(aiohttp.ClientConnectorError): + with mock.patch.object( + conn._loop, + "create_connection", + autospec=True, + spec_set=True, + side_effect=ssl.CertificateError, + ): + await conn.connect(req, [], ClientTimeout()) + return True + + def test_connect() -> Literal[True]: + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(async_connect()) + finally: + loop.close() + + with futures.ThreadPoolExecutor() as executor: + res_list = [executor.submit(test_connect) for _ in range(2)] + raw_response_list = [res.result() for res in futures.as_completed(res_list)] + + assert raw_response_list == [True, True] + + +def test_default_ssl_context_creation_without_ssl() -> None: + """Verify _make_ssl_context does not raise when ssl is not available.""" + with mock.patch.object(connector_module, "ssl", None): + assert connector_module._make_ssl_context(False) is None + assert connector_module._make_ssl_context(True) is None diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 3e4131160ae..bc83f163e29 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -9,6 +9,7 @@ import aiohttp from aiohttp.client_reqrep import ClientRequest, ClientResponse +from aiohttp.connector import _SSL_CONTEXT_VERIFIED from aiohttp.helpers import TimerNoop from aiohttp.test_utils import make_mocked_coro @@ -934,9 +935,7 @@ async def make_conn() -> aiohttp.TCPConnector: tls_m.assert_called_with( mock.ANY, mock.ANY, - self.loop.run_until_complete( - connector._make_or_get_ssl_context(True) - ), + _SSL_CONTEXT_VERIFIED, server_hostname="www.python.org", ssl_handshake_timeout=mock.ANY, )