diff --git a/py/selenium/webdriver/remote/client_config.py b/py/selenium/webdriver/remote/client_config.py index 3b23c97edea62..8ab0c9ec44c14 100644 --- a/py/selenium/webdriver/remote/client_config.py +++ b/py/selenium/webdriver/remote/client_config.py @@ -16,9 +16,12 @@ # under the License. import base64 import os +import socket from typing import Optional from urllib import parse +import certifi + from selenium.webdriver.common.proxy import Proxy from selenium.webdriver.common.proxy import ProxyType @@ -27,8 +30,12 @@ class ClientConfig: def __init__( self, remote_server_addr: str, - keep_alive: bool = True, - proxy: Proxy = Proxy(raw={"proxyType": ProxyType.SYSTEM}), + keep_alive: Optional[bool] = True, + proxy: Optional[Proxy] = Proxy(raw={"proxyType": ProxyType.SYSTEM}), + ignore_certificates: Optional[bool] = False, + init_args_for_pool_manager: Optional[dict] = None, + timeout: Optional[int] = None, + ca_certs: Optional[str] = None, username: Optional[str] = None, password: Optional[str] = None, auth_type: Optional[str] = "Basic", @@ -37,17 +44,38 @@ def __init__( self.remote_server_addr = remote_server_addr self.keep_alive = keep_alive self.proxy = proxy + self.ignore_certificates = ignore_certificates + self.init_args_for_pool_manager = init_args_for_pool_manager or {} + self.timeout = timeout self.username = username self.password = password self.auth_type = auth_type self.token = token + self.timeout = ( + ( + float(os.getenv("GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout()))) + if os.getenv("GLOBAL_DEFAULT_TIMEOUT") is not None + else socket.getdefaulttimeout() + ) + if timeout is None + else timeout + ) + + self.ca_certs = ( + (os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where()) + if ca_certs is None + else ca_certs + ) + @property def remote_server_addr(self) -> str: + """:Returns: The address of the remote server.""" return self._remote_server_addr @remote_server_addr.setter def remote_server_addr(self, value: str) -> None: + """Provides the address of the remote server.""" self._remote_server_addr = value @property @@ -73,45 +101,125 @@ def proxy(self) -> Proxy: def proxy(self, proxy: Proxy) -> None: """Provides the information for communicating with the driver or server. + For example: Proxy(raw={"proxyType": ProxyType.SYSTEM}) :Args: - value: the proxy information to use to communicate with the driver or server """ self._proxy = proxy + @property + def ignore_certificates(self) -> bool: + """:Returns: The ignore certificate check value""" + return self._ignore_certificates + + @ignore_certificates.setter + def ignore_certificates(self, ignore_certificates: bool) -> None: + """Toggles the ignore certificate check. + + :Args: + - value: value of ignore certificate check + """ + self._ignore_certificates = ignore_certificates + + @property + def init_args_for_pool_manager(self) -> dict: + """:Returns: The dictionary of arguments will be appended while initializing the pool manager.""" + return self._init_args_for_pool_manager + + @init_args_for_pool_manager.setter + def init_args_for_pool_manager(self, init_args_for_pool_manager: dict) -> None: + """Provides dictionary of arguments will be appended while initializing the pool manager. + For example: {"init_args_for_pool_manager": {"retries": 3, "block": True}} + + :Args: + - value: the dictionary of arguments will be appended while initializing the pool manager + """ + self._init_args_for_pool_manager = init_args_for_pool_manager + + @property + def timeout(self) -> int: + """:Returns: The timeout (in seconds) used for communicating to the + driver/server.""" + return self._timeout + + @timeout.setter + def timeout(self, timeout: int) -> None: + """Provides the timeout (in seconds) for communicating with the driver + or server. + + :Args: + - value: the timeout (in seconds) to use to communicate with the driver or server + """ + self._timeout = timeout + + def reset_timeout(self) -> None: + """Resets the timeout to the default value of socket.""" + self._timeout = socket.getdefaulttimeout() + + @property + def ca_certs(self) -> str: + """:Returns: The path to bundle of CA certificates.""" + return self._ca_certs + + @ca_certs.setter + def ca_certs(self, ca_certs: str) -> None: + """Provides the path to bundle of CA certificates for establishing + secure connections. + + :Args: + - value: the path to bundle of CA certificates for establishing secure connections + """ + self._ca_certs = ca_certs + @property def username(self) -> str: + """Returns the username used for basic authentication to the remote + server.""" return self._username @username.setter def username(self, value: str) -> None: + """Sets the username used for basic authentication to the remote + server.""" self._username = value @property def password(self) -> str: + """Returns the password used for basic authentication to the remote + server.""" return self._password @password.setter def password(self, value: str) -> None: + """Sets the password used for basic authentication to the remote + server.""" self._password = value @property def auth_type(self) -> str: + """Returns the type of authentication to the remote server.""" return self._auth_type @auth_type.setter def auth_type(self, value: str) -> None: + """Sets the type of authentication to the remote server if it is not + using basic with username and password.""" self._auth_type = value @property def token(self) -> str: + """Returns the token used for authentication to the remote server.""" return self._token @token.setter def token(self, value: str) -> None: + """Sets the token used for authentication to the remote server if + auth_type is not basic.""" self._token = value def get_proxy_url(self) -> Optional[str]: + """Returns the proxy URL to use for the connection.""" proxy_type = self.proxy.proxy_type remote_add = parse.urlparse(self.remote_server_addr) if proxy_type is ProxyType.DIRECT: @@ -136,6 +244,7 @@ def get_proxy_url(self) -> Optional[str]: return None def get_auth_header(self) -> Optional[dict]: + """Returns the authorization to add to the request headers.""" auth_type = self.auth_type.lower() if auth_type == "basic" and self.username and self.password: credentials = f"{self.username}:{self.password}" diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index 78a8bc9438e42..0461d5f5afdcb 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -16,16 +16,13 @@ # under the License. import logging -import os import platform -import socket import string import warnings from base64 import b64encode from typing import Optional from urllib import parse -import certifi import urllib3 from selenium import __version__ @@ -139,12 +136,7 @@ class RemoteConnection: """ browser_name = None - _timeout = ( - float(os.getenv("GLOBAL_DEFAULT_TIMEOUT", str(socket.getdefaulttimeout()))) - if os.getenv("GLOBAL_DEFAULT_TIMEOUT") is not None - else socket.getdefaulttimeout() - ) - _ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where() + _client_config: ClientConfig = None system = platform.system().lower() if system == "darwin": @@ -161,7 +153,12 @@ def get_timeout(cls): Timeout value in seconds for all http requests made to the Remote Connection """ - return None if cls._timeout == socket._GLOBAL_DEFAULT_TIMEOUT else cls._timeout + warnings.warn( + "get_timeout is deprecated, get timeout from ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + return cls._client_config.timeout @classmethod def set_timeout(cls, timeout): @@ -170,12 +167,22 @@ def set_timeout(cls, timeout): :Args: - timeout - timeout value for http requests in seconds """ - cls._timeout = timeout + warnings.warn( + "set_timeout is deprecated, set timeout in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + cls._client_config.timeout = timeout @classmethod def reset_timeout(cls): """Reset the http request timeout to socket._GLOBAL_DEFAULT_TIMEOUT.""" - cls._timeout = socket._GLOBAL_DEFAULT_TIMEOUT + warnings.warn( + "reset_timeout is deprecated, use reset_timeout in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + cls._client_config.reset_timeout() @classmethod def get_certificate_bundle_path(cls): @@ -185,7 +192,12 @@ def get_certificate_bundle_path(cls): command executor. Defaults to certifi.where() or REQUESTS_CA_BUNDLE env variable if set. """ - return cls._ca_certs + warnings.warn( + "get_certificate_bundle_path is deprecated, get certificate bundle path from ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + return cls._client_config.ca_certs @classmethod def set_certificate_bundle_path(cls, path): @@ -196,7 +208,12 @@ def set_certificate_bundle_path(cls, path): :Args: - path - path of a .pem encoded certificate chain. """ - cls._ca_certs = path + warnings.warn( + "set_certificate_bundle_path is deprecated, set certificate bundle path in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + cls._client_config.ca_certs = path @classmethod def get_remote_connection_headers(cls, parsed_url, keep_alive=False): @@ -239,15 +256,17 @@ def _separate_http_proxy_auth(self): return proxy_without_auth, auth def _get_connection_manager(self): - pool_manager_init_args = {"timeout": self.get_timeout()} - pool_manager_init_args.update(self._init_args_for_pool_manager.get("init_args_for_pool_manager", {})) + pool_manager_init_args = {"timeout": self._client_config.timeout} + pool_manager_init_args.update( + self._client_config.init_args_for_pool_manager.get("init_args_for_pool_manager", {}) + ) - if self._ignore_certificates: + if self._client_config.ignore_certificates: pool_manager_init_args["cert_reqs"] = "CERT_NONE" urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - elif self._ca_certs: + elif self._client_config.ca_certs: pool_manager_init_args["cert_reqs"] = "CERT_REQUIRED" - pool_manager_init_args["ca_certs"] = self._ca_certs + pool_manager_init_args["ca_certs"] = self._client_config.ca_certs if self._proxy_url: if self._proxy_url.lower().startswith("sock"): @@ -270,11 +289,13 @@ def __init__( init_args_for_pool_manager: Optional[dict] = None, client_config: Optional[ClientConfig] = None, ): - self.keep_alive = keep_alive - self._url = remote_server_addr - self._ignore_certificates = ignore_certificates - self._init_args_for_pool_manager = init_args_for_pool_manager or {} - self._client_config = client_config or ClientConfig(remote_server_addr, keep_alive) + self._client_config = client_config or ClientConfig( + remote_server_addr=remote_server_addr, + keep_alive=keep_alive, + ignore_certificates=ignore_certificates, + init_args_for_pool_manager=init_args_for_pool_manager, + ) + RemoteConnection._client_config = self._client_config if remote_server_addr: warnings.warn( @@ -282,6 +303,9 @@ def __init__( DeprecationWarning, stacklevel=2, ) + self._url = remote_server_addr + else: + self._url = self._client_config.remote_server_addr if not keep_alive: warnings.warn( @@ -290,6 +314,20 @@ def __init__( stacklevel=2, ) + if ignore_certificates: + warnings.warn( + "setting ignore_certificates in RemoteConnection() is deprecated, set in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + + if init_args_for_pool_manager: + warnings.warn( + "setting init_args_for_pool_manager in RemoteConnection() is deprecated, set in ClientConfig instance instead", + DeprecationWarning, + stacklevel=2, + ) + if ignore_proxy: warnings.warn( "setting ignore_proxy in RemoteConnection() is deprecated, set in ClientConfig instance instead", diff --git a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py index 225d8c634a814..98e9fbaa1314c 100644 --- a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py +++ b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py @@ -127,9 +127,10 @@ def test_get_connection_manager_without_proxy(mock_proxy_settings_missing): assert isinstance(conn, urllib3.PoolManager) -def test_get_connection_manager_for_certs_and_timeout(monkeypatch): - monkeypatch.setattr(RemoteConnection, "get_timeout", lambda _: 10) # Class state; leaks into subsequent tests. +def test_get_connection_manager_for_certs_and_timeout(): remote_connection = RemoteConnection("http://remote", keep_alive=False) + remote_connection.set_timeout(10) + assert remote_connection.get_timeout() == 10 conn = remote_connection._get_connection_manager() assert conn.connection_pool_kw["timeout"] == 10 assert conn.connection_pool_kw["cert_reqs"] == "CERT_REQUIRED" @@ -306,19 +307,42 @@ def test_register_extra_headers(mock_request, remote_connection): assert headers["Foo"] == "bar" +def test_get_connection_manager_with_timeout_from_client_config(): + client_config = ClientConfig("http://remote", timeout=300) + remote_connection = RemoteConnection(None, client_config=client_config) + conn = remote_connection._get_connection_manager() + assert conn.connection_pool_kw["timeout"] == 300 + assert isinstance(conn, urllib3.PoolManager) + + +def test_get_connection_manager_with_ca_certs_from_client_config(): + client_config = ClientConfig("http://remote", ca_certs="/path/to/cacert.pem") + remote_connection = RemoteConnection(None, client_config=client_config) + conn = remote_connection._get_connection_manager() + assert conn.connection_pool_kw["timeout"] is None + assert conn.connection_pool_kw["cert_reqs"] == "CERT_REQUIRED" + assert conn.connection_pool_kw["ca_certs"] == "/path/to/cacert.pem" + assert isinstance(conn, urllib3.PoolManager) + + def test_get_connection_manager_ignores_certificates(monkeypatch): - monkeypatch.setattr(RemoteConnection, "get_timeout", lambda _: 10) - remote_connection = RemoteConnection("http://remote", ignore_certificates=True) + client_config = ClientConfig("http://remote", ignore_certificates=True) + remote_connection = RemoteConnection(None, client_config=client_config) + remote_connection.set_timeout(10) conn = remote_connection._get_connection_manager() assert conn.connection_pool_kw["timeout"] == 10 assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE" assert isinstance(conn, urllib3.PoolManager) + remote_connection.reset_timeout() + assert remote_connection.get_timeout() is None + def test_get_connection_manager_with_custom_args(): custom_args = {"init_args_for_pool_manager": {"retries": 3, "block": True}} - remote_connection = RemoteConnection("http://remote", keep_alive=False, init_args_for_pool_manager=custom_args) + client_config = ClientConfig("http://remote", keep_alive=False, init_args_for_pool_manager=custom_args) + remote_connection = RemoteConnection(None, client_config=client_config) conn = remote_connection._get_connection_manager() assert isinstance(conn, urllib3.PoolManager)