From bcbf4ba3d60c2476d72f4e16915891b1cbff4d2c Mon Sep 17 00:00:00 2001 From: Ian Stapleton Cordasco Date: Sun, 3 Mar 2024 07:00:49 -0600 Subject: [PATCH] Use TLS settings in selecting connection pool Previously, if someone made a request with `verify=False` then made a request where they expected verification to be enabled to the same host, they would potentially reuse a connection where TLS had not been verified. This fixes that issue. --- src/requests/adapters.py | 57 +++++++++++++++++++++++++++++++++++++++- tests/test_requests.py | 7 +++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/requests/adapters.py b/src/requests/adapters.py index fc5606bdcb..a0ee292a77 100644 --- a/src/requests/adapters.py +++ b/src/requests/adapters.py @@ -8,6 +8,7 @@ import os.path import socket # noqa: F401 +import typing from urllib3.exceptions import ClosedPoolError, ConnectTimeoutError from urllib3.exceptions import HTTPError as _HTTPError @@ -61,12 +62,44 @@ def SOCKSProxyManager(*args, **kwargs): raise InvalidSchema("Missing dependencies for SOCKS support.") +if typing.TYPE_CHECKING: + from .models import PreparedRequest + + DEFAULT_POOLBLOCK = False DEFAULT_POOLSIZE = 10 DEFAULT_RETRIES = 0 DEFAULT_POOL_TIMEOUT = None +def _urllib3_request_context( + request: "PreparedRequest", verify: "bool | str | None" +) -> "typing.Dict[str, typing.Any]": + context = {} + parsed_request_url = urlparse(request.url) + scheme = parsed_request_url.scheme.lower() + # In case the URL scheme is not entirely lower-case, we need to normalize + # that for our dictionary but also for urllib3 + default_ports = {"http": 80, "https": 443} + port = parsed_request_url.port + if port is None: + port = default_ports.get(scheme) + cert_reqs = "CERT_REQUIRED" + if verify is False: + cert_reqs = "CERT_NONE" + if isinstance(verify, str): + context["ca_certs"] = verify + context.update( + { + "scheme": scheme, + "host": parsed_request_url.hostname, + "port": port, + "cert_reqs": cert_reqs, + } + ) + return context + + class BaseAdapter: """The Base Transport Adapter""" @@ -327,6 +360,28 @@ def build_response(self, req, resp): return response + def _get_connection(self, request, verify, proxies=None): + # Replace the existing get_connection without breaking things and + # ensure that TLS settings are considered when we interact with + # urllib3 HTTP Pools + proxy = select_proxy(request.url, proxies) + urllib3_context = _urllib3_request_context(request, verify) + if proxy: + proxy = prepend_scheme_if_needed(proxy, "http") + proxy_url = parse_url(proxy) + if not proxy_url.host: + raise InvalidProxyURL( + "Please check proxy URL. It is malformed " + "and could be missing the host." + ) + proxy_manager = self.proxy_manager_for(proxy) + conn = proxy_manager.connection_from_context(urllib3_context) + else: + # Only scheme should be lower case + conn = self.poolmanager.connection_from_context(urllib3_context) + + return conn + def get_connection(self, url, proxies=None): """Returns a urllib3 connection for the given URL. This should not be called from user code, and is only exposed for use when subclassing the @@ -453,7 +508,7 @@ def send( """ try: - conn = self.get_connection(request.url, proxies) + conn = self._get_connection(request, verify, proxies) except LocationValueError as e: raise InvalidURL(e, request=request) diff --git a/tests/test_requests.py b/tests/test_requests.py index 32b5e6700c..d5cc13c79f 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -2828,6 +2828,13 @@ def test_status_code_425(self): assert r5 == 425 assert r6 == 425 + def test_different_connection_pool_for_tls_settings(self): + s = requests.Session() + r1 = s.get("https://invalid.badssl.com", verify=False) + assert r1.status_code == 421 + with pytest.raises(requests.exceptions.SSLError): + s.get("https://invalid.badssl.com") + def test_json_decode_errors_are_serializable_deserializable(): json_decode_error = requests.exceptions.JSONDecodeError(