Skip to content

Commit

Permalink
Tightening the runtime type check for ssl (#7698)
Browse files Browse the repository at this point in the history
Currently, the valid types of ssl parameter are SSLContext,
Literal[False], Fingerprint or None.

If user sets ssl = False, we disable ssl certificate validation which
makes total sense. But if user set ssl = True by mistake, instead of
enabling ssl certificate validation or raising errors, we silently
disable the validation too which is a little subtle but weird.

In this PR, we added a check that if user sets ssl=True, we enable
certificate validation by treating it as using Default SSL Context.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sviatoslav Sydorenko <wk.cvs.github@sydorenko.org.ua>
Co-authored-by: Sam Bull <aa6bs0@sambull.org>
Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Sam Bull <git@sambull.org>
  • Loading branch information
6 people authored Jan 20, 2024
1 parent 2670e7b commit 9e14ea1
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGES/7698.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added support for passing `True` to `ssl` while deprecating `None`. -- by :user:`xiangyan99`
28 changes: 18 additions & 10 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Set,
Expand Down Expand Up @@ -364,7 +363,7 @@ async def _request(
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
Expand All @@ -382,8 +381,8 @@ async def _request(

if not isinstance(ssl, SSL_ALLOWED_TYPES):
raise TypeError(
"ssl should be SSLContext, bool, Fingerprint, "
"or None, got {!r} instead.".format(ssl)
"ssl should be SSLContext, Fingerprint, or bool, "
"got {!r} instead.".format(ssl)
)

if data is not None and json is not None:
Expand Down Expand Up @@ -513,7 +512,7 @@ async def _request(
proxy_auth=proxy_auth,
timer=timer,
session=self,
ssl=ssl,
ssl=ssl if ssl is not None else True, # type: ignore[redundant-expr]
server_hostname=server_hostname,
proxy_headers=proxy_headers,
traces=traces,
Expand Down Expand Up @@ -702,7 +701,7 @@ def ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
Expand All @@ -725,7 +724,7 @@ def ws_connect(
headers=headers,
proxy=proxy,
proxy_auth=proxy_auth,
ssl=ssl,
ssl=ssl if ssl is not None else True, # type: ignore[redundant-expr]
server_hostname=server_hostname,
proxy_headers=proxy_headers,
compress=compress,
Expand All @@ -750,7 +749,7 @@ async def _ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
Expand Down Expand Up @@ -806,10 +805,19 @@ async def _ws_connect(
extstr = ws_ext_gen(compress=compress)
real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr

# For the sake of backward compatibility, if user passes in None, convert it to True
if ssl is None:
warnings.warn(
"ssl=None is deprecated, please use ssl=True",
DeprecationWarning,
stacklevel=2,
)
ssl = True

if not isinstance(ssl, SSL_ALLOWED_TYPES):
raise TypeError(
"ssl should be SSLContext, bool, Fingerprint, "
"or None, got {!r} instead.".format(ssl)
"ssl should be SSLContext, Fingerprint, or bool, "
"got {!r} instead.".format(ssl)
)

# send request
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,12 @@ def port(self) -> Optional[int]:
return self._conn_key.port

@property
def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]:
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
return self._conn_key.ssl

def __str__(self) -> str:
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)

# OSError.__reduce__ does too much black magick
Expand Down Expand Up @@ -188,7 +188,7 @@ def path(self) -> str:

def __str__(self) -> str:
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)


Expand Down
11 changes: 5 additions & 6 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -149,7 +148,7 @@ def check(self, transport: asyncio.Transport) -> None:
if ssl is not None:
SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
else: # pragma: no cover
SSL_ALLOWED_TYPES = type(None)
SSL_ALLOWED_TYPES = (bool, type(None))


@dataclasses.dataclass(frozen=True)
Expand All @@ -159,7 +158,7 @@ class ConnectionKey:
host: str
port: Optional[int]
is_ssl: bool
ssl: Union[SSLContext, None, Literal[False], Fingerprint]
ssl: Union[SSLContext, bool, Fingerprint]
proxy: Optional[URL]
proxy_auth: Optional[BasicAuth]
proxy_headers_hash: Optional[int] # hash(CIMultiDict)
Expand Down Expand Up @@ -213,7 +212,7 @@ def __init__(
proxy_auth: Optional[BasicAuth] = None,
timer: Optional[BaseTimerContext] = None,
session: Optional["ClientSession"] = None,
ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
Expand Down Expand Up @@ -248,7 +247,7 @@ def __init__(
real_response_class = response_class
self.response_class: Type[ClientResponse] = real_response_class
self._timer = timer if timer is not None else TimerNoop()
self._ssl = ssl
self._ssl = ssl if ssl is not None else True # type: ignore[redundant-expr]
self.server_hostname = server_hostname

if loop.get_debug():
Expand Down Expand Up @@ -290,7 +289,7 @@ def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

@property
def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]:
def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
return self._ssl

@property
Expand Down
10 changes: 5 additions & 5 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def __init__(
use_dns_cache: bool = True,
ttl_dns_cache: Optional[int] = 10,
family: int = 0,
ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None,
ssl: Union[bool, Fingerprint, SSLContext] = True,
local_addr: Optional[Tuple[str, int]] = None,
resolver: Optional[AbstractResolver] = None,
keepalive_timeout: Union[None, float, _SENTINEL] = sentinel,
Expand All @@ -769,8 +769,8 @@ def __init__(

if not isinstance(ssl, SSL_ALLOWED_TYPES):
raise TypeError(
"ssl should be SSLContext, bool, Fingerprint, "
"or None, got {!r} instead.".format(ssl)
"ssl should be SSLContext, Fingerprint, or bool, "
"got {!r} instead.".format(ssl)
)
self._ssl = ssl
if resolver is None:
Expand Down Expand Up @@ -942,13 +942,13 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
sslcontext = req.ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not None:
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not None:
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
return self._make_ssl_context(True)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class TestClientConnectorError:
host="example.com",
port=8080,
is_ssl=False,
ssl=None,
ssl=True,
proxy=None,
proxy_auth=None,
proxy_headers_hash=None,
Expand All @@ -106,7 +106,7 @@ def test_ctor(self) -> None:
assert err.os_error.strerror == "No such file"
assert err.host == "example.com"
assert err.port == 8080
assert err.ssl is None
assert err.ssl is True

def test_pickle(self) -> None:
err = client.ClientConnectorError(
Expand All @@ -123,7 +123,7 @@ def test_pickle(self) -> None:
assert err2.os_error.strerror == "No such file"
assert err2.host == "example.com"
assert err2.port == 8080
assert err2.ssl is None
assert err2.ssl is True
assert err2.foo == "bar"

def test_repr(self) -> None:
Expand All @@ -141,7 +141,7 @@ def test_str(self) -> None:
os_error=OSError(errno.ENOENT, "No such file"),
)
assert str(err) == (
"Cannot connect to host example.com:8080 ssl:" "default [No such file]"
"Cannot connect to host example.com:8080 ssl:default [No such file]"
)


Expand All @@ -150,7 +150,7 @@ class TestClientConnectorCertificateError:
host="example.com",
port=8080,
is_ssl=False,
ssl=None,
ssl=True,
proxy=None,
proxy_auth=None,
proxy_headers_hash=None,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_host_port_default_http(make_request: Any) -> None:
req = make_request("get", "http://python.org/")
assert req.host == "python.org"
assert req.port == 80
assert not req.ssl
assert not req.is_ssl()


def test_host_port_default_https(make_request: Any) -> None:
Expand Down Expand Up @@ -391,7 +391,7 @@ def test_ipv6_default_http_port(make_request: Any) -> None:
req = make_request("get", "http://[2001:db8::1]/")
assert req.host == "2001:db8::1"
assert req.port == 80
assert not req.ssl
assert not req.is_ssl()


def test_ipv6_default_https_port(make_request: Any) -> None:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@
@pytest.fixture()
def key():
# Connection key
return ConnectionKey("localhost", 80, False, None, None, None, None)
return ConnectionKey("localhost", 80, False, True, None, None, None)


@pytest.fixture
def key2():
# Connection key
return ConnectionKey("localhost", 80, False, None, None, None, None)
return ConnectionKey("localhost", 80, False, True, None, None, None)


@pytest.fixture
def ssl_key():
# Connection key
return ConnectionKey("localhost", 80, True, None, None, None, None)
return ConnectionKey("localhost", 80, True, True, None, None, None)


@pytest.fixture
Expand Down Expand Up @@ -1478,7 +1478,7 @@ async def test_cleanup_closed_disabled(loop: Any, mocker: Any) -> None:

async def test_tcp_connector_ctor(loop: Any) -> None:
conn = aiohttp.TCPConnector()
assert conn._ssl is None
assert conn._ssl is True

assert conn.use_dns_cache
assert conn.family == 0
Expand Down Expand Up @@ -1565,7 +1565,7 @@ async def test___get_ssl_context3(loop: Any) -> None:
conn = aiohttp.TCPConnector(ssl=ctx)
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = None
req.ssl = True
assert conn._get_ssl_context(req) is ctx


Expand All @@ -1591,7 +1591,7 @@ async def test___get_ssl_context6(loop: Any) -> None:
conn = aiohttp.TCPConnector()
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = None
req.ssl = True
assert conn._get_ssl_context(req) is conn._make_ssl_context(True)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def make_conn():
auth=None,
headers={"Host": "www.python.org"},
loop=self.loop,
ssl=None,
ssl=True,
)

conn.close()
Expand Down Expand Up @@ -146,7 +146,7 @@ async def make_conn():
auth=None,
headers={"Host": "www.python.org", "Foo": "Bar"},
loop=self.loop,
ssl=None,
ssl=True,
)

conn.close()
Expand Down

0 comments on commit 9e14ea1

Please sign in to comment.