From e46294a99c9f095dbfab178700ff79853e98c7ef Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Sat, 20 Jan 2024 11:59:39 -0800 Subject: [PATCH 1/6] Tightening the runtime type check for ssl (#7698) Currently, the valid types of ssl parameter are SSLContext, Literal[False], Fingerprint or None. If user sets ssl = False, we disable ssl certificate validation which makes total sense. But if user set ssl = True by mistake, instead of enabling ssl certificate validation or raising errors, we silently disable the validation too which is a little subtle but weird. In this PR, we added a check that if user sets ssl=True, we enable certificate validation by treating it as using Default SSL Context. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sviatoslav Sydorenko Co-authored-by: Sam Bull Co-authored-by: J. Nick Koston Co-authored-by: Sam Bull (cherry picked from commit 9e14ea19b5a48bb26797babc32202605066cb5f5) --- CHANGES/7698.feature | 1 + aiohttp/client.py | 22 +++++++++++++++++----- aiohttp/client_exceptions.py | 6 +++--- aiohttp/client_reqrep.py | 19 +++++++++---------- aiohttp/connector.py | 11 ++++++++--- tests/test_client_exceptions.py | 10 +++++----- tests/test_client_request.py | 4 ++-- tests/test_connector.py | 16 ++++++++-------- tests/test_proxy.py | 4 ++-- 9 files changed, 55 insertions(+), 38 deletions(-) create mode 100644 CHANGES/7698.feature diff --git a/CHANGES/7698.feature b/CHANGES/7698.feature new file mode 100644 index 00000000000..e8c4b3fb452 --- /dev/null +++ b/CHANGES/7698.feature @@ -0,0 +1 @@ +Added support for passing `True` to `ssl` while deprecating `None`. -- by :user:`xiangyan99` diff --git a/aiohttp/client.py b/aiohttp/client.py index d08211bd00e..8d91fbc1550 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -22,7 +22,6 @@ Generic, Iterable, List, - Literal, Mapping, Optional, Set, @@ -415,7 +414,7 @@ async def _request( verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, - ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, server_hostname: Optional[str] = None, proxy_headers: Optional[LooseHeaders] = None, trace_request_ctx: Optional[SimpleNamespace] = None, @@ -432,6 +431,11 @@ async def _request( if self.closed: raise RuntimeError("Session is closed") + if not isinstance(ssl, SSL_ALLOWED_TYPES): + raise TypeError( + "ssl should be SSLContext, Fingerprint, or bool, " + "got {!r} instead.".format(ssl) + ) ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) if data is not None and json is not None: @@ -571,7 +575,7 @@ async def _request( proxy_auth=proxy_auth, timer=timer, session=self, - ssl=ssl, + ssl=ssl if ssl is not None else True, # type: ignore[redundant-expr] server_hostname=server_hostname, proxy_headers=proxy_headers, traces=traces, @@ -752,7 +756,7 @@ def ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -804,7 +808,7 @@ async def _ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, verify_ssl: Optional[bool] = None, fingerprint: Optional[bytes] = None, ssl_context: Optional[SSLContext] = None, @@ -838,6 +842,14 @@ async def _ws_connect( extstr = ws_ext_gen(compress=compress) real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr + # For the sake of backward compatibility, if user passes in None, convert it to True + if ssl is None: + warnings.warn( + "ssl=None is deprecated, please use ssl=True", + DeprecationWarning, + stacklevel=2, + ) + ssl = True ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) # send request diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index d70988f6ede..60bf058e887 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -182,12 +182,12 @@ def port(self) -> Optional[int]: return self._conn_key.port @property - def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]: + def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]: return self._conn_key.ssl def __str__(self) -> str: return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format( - self, self.ssl if self.ssl is not None else "default", self.strerror + self, "default" if self.ssl is True else self.ssl, self.strerror ) # OSError.__reduce__ does too much black magick @@ -221,7 +221,7 @@ def path(self) -> str: def __str__(self) -> str: return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format( - self, self.ssl if self.ssl is not None else "default", self.strerror + self, "default" if self.ssl is True else self.ssl, self.strerror ) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 4ae0ecbcdfb..fca3549c2c7 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -17,7 +17,6 @@ Dict, Iterable, List, - Literal, Mapping, Optional, Tuple, @@ -151,11 +150,11 @@ def check(self, transport: asyncio.Transport) -> None: if ssl is not None: SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None)) else: # pragma: no cover - SSL_ALLOWED_TYPES = type(None) + SSL_ALLOWED_TYPES = (bool, type(None)) def _merge_ssl_params( - ssl: Union["SSLContext", Literal[False], Fingerprint, None], + ssl: Union["SSLContext", bool, Fingerprint, None], verify_ssl: Optional[bool], ssl_context: Optional["SSLContext"], fingerprint: Optional[bytes], @@ -166,7 +165,7 @@ def _merge_ssl_params( DeprecationWarning, stacklevel=3, ) - if ssl is not None: + if ssl is not True: raise ValueError( "verify_ssl, ssl_context, fingerprint and ssl " "parameters are mutually exclusive" @@ -179,7 +178,7 @@ def _merge_ssl_params( DeprecationWarning, stacklevel=3, ) - if ssl is not None: + if ssl is not True: raise ValueError( "verify_ssl, ssl_context, fingerprint and ssl " "parameters are mutually exclusive" @@ -192,7 +191,7 @@ def _merge_ssl_params( DeprecationWarning, stacklevel=3, ) - if ssl is not None: + if ssl is not True: raise ValueError( "verify_ssl, ssl_context, fingerprint and ssl " "parameters are mutually exclusive" @@ -214,7 +213,7 @@ class ConnectionKey: host: str port: Optional[int] is_ssl: bool - ssl: Union[SSLContext, None, Literal[False], Fingerprint] + ssl: Union[SSLContext, bool, Fingerprint] proxy: Optional[URL] proxy_auth: Optional[BasicAuth] proxy_headers_hash: Optional[int] # hash(CIMultiDict) @@ -276,7 +275,7 @@ def __init__( proxy_auth: Optional[BasicAuth] = None, timer: Optional[BaseTimerContext] = None, session: Optional["ClientSession"] = None, - ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, proxy_headers: Optional[LooseHeaders] = None, traces: Optional[List["Trace"]] = None, trust_env: bool = False, @@ -315,7 +314,7 @@ def __init__( real_response_class = response_class self.response_class: Type[ClientResponse] = real_response_class self._timer = timer if timer is not None else TimerNoop() - self._ssl = ssl + self._ssl = ssl if ssl is not None else True # type: ignore[redundant-expr] self.server_hostname = server_hostname if loop.get_debug(): @@ -357,7 +356,7 @@ def is_ssl(self) -> bool: return self.url.scheme in ("https", "wss") @property - def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]: + def ssl(self) -> Union["SSLContext", bool, Fingerprint]: return self._ssl @property diff --git a/aiohttp/connector.py b/aiohttp/connector.py index baa3a7170f6..cf620f6ca6e 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -768,7 +768,7 @@ def __init__( ttl_dns_cache: Optional[int] = 10, family: int = 0, ssl_context: Optional[SSLContext] = None, - ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None, + ssl: Union[bool, Fingerprint, SSLContext] = True, local_addr: Optional[Tuple[str, int]] = None, resolver: Optional[AbstractResolver] = None, keepalive_timeout: Union[None, float, object] = sentinel, @@ -791,6 +791,11 @@ def __init__( timeout_ceil_threshold=timeout_ceil_threshold, ) + if not isinstance(ssl, SSL_ALLOWED_TYPES): + raise TypeError( + "ssl should be SSLContext, Fingerprint, or bool, " + "got {!r} instead.".format(ssl) + ) self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) if resolver is None: resolver = DefaultResolver(loop=self._loop) @@ -965,13 +970,13 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: sslcontext = req.ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext - if sslcontext is not None: + if sslcontext is not True: # not verified or fingerprinted return self._make_ssl_context(False) sslcontext = self._ssl if isinstance(sslcontext, ssl.SSLContext): return sslcontext - if sslcontext is not None: + if sslcontext is not True: # not verified or fingerprinted return self._make_ssl_context(False) return self._make_ssl_context(True) diff --git a/tests/test_client_exceptions.py b/tests/test_client_exceptions.py index 8f34e4cc73c..f70ba5d09a6 100644 --- a/tests/test_client_exceptions.py +++ b/tests/test_client_exceptions.py @@ -119,7 +119,7 @@ class TestClientConnectorError: host="example.com", port=8080, is_ssl=False, - ssl=None, + ssl=True, proxy=None, proxy_auth=None, proxy_headers_hash=None, @@ -136,7 +136,7 @@ def test_ctor(self) -> None: assert err.os_error.strerror == "No such file" assert err.host == "example.com" assert err.port == 8080 - assert err.ssl is None + assert err.ssl is True def test_pickle(self) -> None: err = client.ClientConnectorError( @@ -153,7 +153,7 @@ def test_pickle(self) -> None: assert err2.os_error.strerror == "No such file" assert err2.host == "example.com" assert err2.port == 8080 - assert err2.ssl is None + assert err2.ssl is True assert err2.foo == "bar" def test_repr(self) -> None: @@ -171,7 +171,7 @@ def test_str(self) -> None: os_error=OSError(errno.ENOENT, "No such file"), ) assert str(err) == ( - "Cannot connect to host example.com:8080 ssl:" "default [No such file]" + "Cannot connect to host example.com:8080 ssl:default [No such file]" ) @@ -180,7 +180,7 @@ class TestClientConnectorCertificateError: host="example.com", port=8080, is_ssl=False, - ssl=None, + ssl=True, proxy=None, proxy_auth=None, proxy_headers_hash=None, diff --git a/tests/test_client_request.py b/tests/test_client_request.py index c8ce98d4034..6521b70ad55 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -166,7 +166,7 @@ def test_host_port_default_http(make_request) -> None: req = make_request("get", "http://python.org/") assert req.host == "python.org" assert req.port == 80 - assert not req.ssl + assert not req.is_ssl() def test_host_port_default_https(make_request) -> None: @@ -400,7 +400,7 @@ def test_ipv6_default_http_port(make_request) -> None: req = make_request("get", "http://[2001:db8::1]/") assert req.host == "2001:db8::1" assert req.port == 80 - assert not req.ssl + assert not req.is_ssl() def test_ipv6_default_https_port(make_request) -> None: diff --git a/tests/test_connector.py b/tests/test_connector.py index 1faec002487..84c03fc6fb5 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -30,19 +30,19 @@ @pytest.fixture() def key(): # Connection key - return ConnectionKey("localhost", 80, False, None, None, None, None) + return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture def key2(): # Connection key - return ConnectionKey("localhost", 80, False, None, None, None, None) + return ConnectionKey("localhost", 80, False, True, None, None, None) @pytest.fixture def ssl_key(): # Connection key - return ConnectionKey("localhost", 80, True, None, None, None, None) + return ConnectionKey("localhost", 80, True, True, None, None, None) @pytest.fixture @@ -1467,9 +1467,9 @@ async def test_cleanup_closed_disabled(loop, mocker) -> None: assert not conn._cleanup_closed_transports -async def test_tcp_connector_ctor(loop) -> None: - conn = aiohttp.TCPConnector(loop=loop) - assert conn._ssl is None +async def test_tcp_connector_ctor() -> None: + conn = aiohttp.TCPConnector() + assert conn._ssl is True assert conn.use_dns_cache assert conn.family == 0 @@ -1555,7 +1555,7 @@ async def test___get_ssl_context3(loop) -> None: conn = aiohttp.TCPConnector(loop=loop, ssl=ctx) req = mock.Mock() req.is_ssl.return_value = True - req.ssl = None + req.ssl = True assert conn._get_ssl_context(req) is ctx @@ -1581,7 +1581,7 @@ async def test___get_ssl_context6(loop) -> None: conn = aiohttp.TCPConnector(loop=loop) req = mock.Mock() req.is_ssl.return_value = True - req.ssl = None + req.ssl = True assert conn._get_ssl_context(req) is conn._make_ssl_context(True) diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 2a8643f5047..f335e42c254 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -92,7 +92,7 @@ async def make_conn(): auth=None, headers={"Host": "www.python.org"}, loop=self.loop, - ssl=None, + ssl=True, ) conn.close() @@ -150,7 +150,7 @@ async def make_conn(): auth=None, headers={"Host": "www.python.org", "Foo": "Bar"}, loop=self.loop, - ssl=None, + ssl=True, ) conn.close() From ed86b27b0d0259d240c1a9aee2213a94ea902412 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 20 Jan 2024 21:10:19 +0000 Subject: [PATCH 2/6] Apply suggestions from code review --- aiohttp/client.py | 5 ----- aiohttp/connector.py | 5 ----- 2 files changed, 10 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index 8d91fbc1550..998396c3edc 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -431,11 +431,6 @@ async def _request( if self.closed: raise RuntimeError("Session is closed") - if not isinstance(ssl, SSL_ALLOWED_TYPES): - raise TypeError( - "ssl should be SSLContext, Fingerprint, or bool, " - "got {!r} instead.".format(ssl) - ) ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) if data is not None and json is not None: diff --git a/aiohttp/connector.py b/aiohttp/connector.py index cf620f6ca6e..d0954355244 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -791,11 +791,6 @@ def __init__( timeout_ceil_threshold=timeout_ceil_threshold, ) - if not isinstance(ssl, SSL_ALLOWED_TYPES): - raise TypeError( - "ssl should be SSLContext, Fingerprint, or bool, " - "got {!r} instead.".format(ssl) - ) self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) if resolver is None: resolver = DefaultResolver(loop=self._loop) From 37ac9cf66bb81584e2e09a832ee8edfb18c663f2 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 20 Jan 2024 21:46:51 +0000 Subject: [PATCH 3/6] Update client_reqrep.py --- aiohttp/client_reqrep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index fca3549c2c7..9037d656d4b 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -158,7 +158,7 @@ def _merge_ssl_params( verify_ssl: Optional[bool], ssl_context: Optional["SSLContext"], fingerprint: Optional[bytes], -) -> Union["SSLContext", Literal[False], Fingerprint, None]: +) -> Union["SSLContext", bool, Fingerprint]: if verify_ssl is not None and not verify_ssl: warnings.warn( "verify_ssl is deprecated, use ssl=False instead", From 30cc72467c54230af956a5a43ec4604fe06c42ad Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 20 Jan 2024 21:49:36 +0000 Subject: [PATCH 4/6] Update client_reqrep.py --- aiohttp/client_reqrep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 9037d656d4b..efbc8767c0c 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -154,7 +154,7 @@ def check(self, transport: asyncio.Transport) -> None: def _merge_ssl_params( - ssl: Union["SSLContext", bool, Fingerprint, None], + ssl: Union["SSLContext", bool, Fingerprint], verify_ssl: Optional[bool], ssl_context: Optional["SSLContext"], fingerprint: Optional[bytes], From 045ec4be33f3b89ae3939e0c2c67f2688f809f4b Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 20 Jan 2024 22:22:54 +0000 Subject: [PATCH 5/6] Apply suggestions from code review --- aiohttp/client.py | 2 +- aiohttp/client_reqrep.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index 998396c3edc..36dbf6a7119 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -570,7 +570,7 @@ async def _request( proxy_auth=proxy_auth, timer=timer, session=self, - ssl=ssl if ssl is not None else True, # type: ignore[redundant-expr] + ssl=ssl if ssl is not None else True, server_hostname=server_hostname, proxy_headers=proxy_headers, traces=traces, diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index efbc8767c0c..bb43ae9318d 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -314,7 +314,7 @@ def __init__( real_response_class = response_class self.response_class: Type[ClientResponse] = real_response_class self._timer = timer if timer is not None else TimerNoop() - self._ssl = ssl if ssl is not None else True # type: ignore[redundant-expr] + self._ssl = ssl if ssl is not None else True self.server_hostname = server_hostname if loop.get_debug(): From 47ed9fe25ce800dc4a8d3909fd35ecb6acafe3c3 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Sat, 20 Jan 2024 22:32:24 +0000 Subject: [PATCH 6/6] Update test_client_fingerprint.py --- tests/test_client_fingerprint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_client_fingerprint.py b/tests/test_client_fingerprint.py index b1ae3cae36e..68dd528e0a2 100644 --- a/tests/test_client_fingerprint.py +++ b/tests/test_client_fingerprint.py @@ -37,7 +37,7 @@ def test_fingerprint_check_no_ssl() -> None: def test__merge_ssl_params_verify_ssl() -> None: with pytest.warns(DeprecationWarning): - assert _merge_ssl_params(None, False, None, None) is False + assert _merge_ssl_params(True, False, None, None) is False def test__merge_ssl_params_verify_ssl_conflict() -> None: @@ -50,7 +50,7 @@ def test__merge_ssl_params_verify_ssl_conflict() -> None: def test__merge_ssl_params_ssl_context() -> None: ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) with pytest.warns(DeprecationWarning): - assert _merge_ssl_params(None, None, ctx, None) is ctx + assert _merge_ssl_params(True, None, ctx, None) is ctx def test__merge_ssl_params_ssl_context_conflict() -> None: @@ -64,7 +64,7 @@ def test__merge_ssl_params_ssl_context_conflict() -> None: def test__merge_ssl_params_fingerprint() -> None: digest = hashlib.sha256(b"123").digest() with pytest.warns(DeprecationWarning): - ret = _merge_ssl_params(None, None, None, digest) + ret = _merge_ssl_params(True, None, None, digest) assert ret.fingerprint == digest