diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py new file mode 100644 index 000000000..e1c816f71 --- /dev/null +++ b/google/auth/transport/_mtls_helper.py @@ -0,0 +1,153 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for getting mTLS cert and key, for internal use only.""" + +import json +import logging +from os import path +import re +import subprocess + +CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json" +_CERT_PROVIDER_COMMAND = "cert_provider_command" +_CERT_REGEX = re.compile( + b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL +) + +# support various format of key files, e.g. +# "-----BEGIN PRIVATE KEY-----...", +# "-----BEGIN EC PRIVATE KEY-----...", +# "-----BEGIN RSA PRIVATE KEY-----..." +_KEY_REGEX = re.compile( + b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?", + re.DOTALL, +) + +_LOGGER = logging.getLogger(__name__) + + +def _check_dca_metadata_path(metadata_path): + """Checks for context aware metadata. If it exists, returns the absolute path; + otherwise returns None. + + Args: + metadata_path (str): context aware metadata path. + + Returns: + str: absolute path if exists and None otherwise. + """ + metadata_path = path.expanduser(metadata_path) + if not path.exists(metadata_path): + _LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path) + return None + return metadata_path + + +def _read_dca_metadata_file(metadata_path): + """Loads context aware metadata from the given path. + + Args: + metadata_path (str): context aware metadata path. + + Returns: + Dict[str, str]: The metadata. + + Raises: + ValueError: If failed to parse metadata as JSON. + """ + with open(metadata_path) as f: + metadata = json.load(f) + + return metadata + + +def get_client_ssl_credentials(metadata_json): + """Returns the client side mTLS cert and key. + + Args: + metadata_json (Dict[str, str]): metadata JSON file which contains the cert + provider command. + + Returns: + Tuple[bytes, bytes]: client certificate and key, both in PEM format. + + Raises: + OSError: If the cert provider command failed to run. + RuntimeError: If the cert provider command has a runtime error. + ValueError: If the metadata json file doesn't contain the cert provider + command or if the command doesn't produce both the client certificate + and client key. + """ + # TODO: implement an in-memory cache of cert and key so we don't have to + # run cert provider command every time. + + # Check the cert provider command existence in the metadata json file. + if _CERT_PROVIDER_COMMAND not in metadata_json: + raise ValueError("Cert provider command is not found") + + # Execute the command. It throws OsError in case of system failure. + command = metadata_json[_CERT_PROVIDER_COMMAND] + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + + # Check cert provider command execution error. + if process.returncode != 0: + raise RuntimeError( + "Cert provider command returns non-zero status code %s" % process.returncode + ) + + # Extract certificate (chain) and key. + cert_match = re.findall(_CERT_REGEX, stdout) + if len(cert_match) != 1: + raise ValueError("Client SSL certificate is missing or invalid") + key_match = re.findall(_KEY_REGEX, stdout) + if len(key_match) != 1: + raise ValueError("Client SSL key is missing or invalid") + return cert_match[0], key_match[0] + + +def get_client_cert_and_key(client_cert_callback=None): + """Returns the client side certificate and private key. The function first + tries to get certificate and key from client_cert_callback; if the callback + is None or doesn't provide certificate and key, the function tries application + default SSL credentials. + + Args: + client_cert_callback (Optional[Callable[[], (bool, bytes, bytes)]]): A + callback which returns a bool indicating if the call is successful, + and client certificate bytes and private key bytes both in PEM format. + + Returns: + Tuple[bool, bytes, bytes]: + A boolean indicating if cert and key are obtained, the cert bytes + and key bytes both in PEM format. + + Raises: + OSError: If the cert provider command failed to run. + RuntimeError: If the cert provider command has a runtime error. + ValueError: If the metadata json file doesn't contain the cert provider + command or if the command doesn't produce both the client certificate + and client key. + """ + if client_cert_callback: + return client_cert_callback() + + metadata_path = _check_dca_metadata_path(CONTEXT_AWARE_METADATA_PATH) + if metadata_path: + metadata = _read_dca_metadata_file(metadata_path) + cert, key = get_client_ssl_credentials(metadata) + return True, cert, key + + return False, None, None diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index 32f59e56b..3d24a551d 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -35,10 +35,14 @@ ) import requests.adapters # pylint: disable=ungrouped-imports import requests.exceptions # pylint: disable=ungrouped-imports +from requests.packages.urllib3.util.ssl_ import ( + create_urllib3_context, +) # pylint: disable=ungrouped-imports import six # pylint: disable=ungrouped-imports from google.auth import exceptions from google.auth import transport +import google.auth.transport._mtls_helper _LOGGER = logging.getLogger(__name__) @@ -182,6 +186,52 @@ def __call__( six.raise_from(new_exc, caught_exc) +class _MutualTlsAdapter(requests.adapters.HTTPAdapter): + """ + A TransportAdapter that enables mutual TLS. + + Args: + cert (bytes): client certificate in PEM format + key (bytes): client private key in PEM format + + Raises: + ImportError: if certifi or pyOpenSSL is not installed + OpenSSL.crypto.Error: if client cert or key is invalid + """ + + def __init__(self, cert, key): + import certifi + from OpenSSL import crypto + import urllib3.contrib.pyopenssl + + urllib3.contrib.pyopenssl.inject_into_urllib3() + + pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) + x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + + ctx_poolmanager = create_urllib3_context() + ctx_poolmanager.load_verify_locations(cafile=certifi.where()) + ctx_poolmanager._ctx.use_certificate(x509) + ctx_poolmanager._ctx.use_privatekey(pkey) + self._ctx_poolmanager = ctx_poolmanager + + ctx_proxymanager = create_urllib3_context() + ctx_proxymanager.load_verify_locations(cafile=certifi.where()) + ctx_proxymanager._ctx.use_certificate(x509) + ctx_proxymanager._ctx.use_privatekey(pkey) + self._ctx_proxymanager = ctx_proxymanager + + super(_MutualTlsAdapter, self).__init__() + + def init_poolmanager(self, *args, **kwargs): + kwargs["ssl_context"] = self._ctx_poolmanager + super(_MutualTlsAdapter, self).init_poolmanager(*args, **kwargs) + + def proxy_manager_for(self, *args, **kwargs): + kwargs["ssl_context"] = self._ctx_proxymanager + return super(_MutualTlsAdapter, self).proxy_manager_for(*args, **kwargs) + + class AuthorizedSession(requests.Session): """A Requests Session class with credentials. @@ -198,6 +248,49 @@ class AuthorizedSession(requests.Session): The underlying :meth:`request` implementation handles adding the credentials' headers to the request and refreshing credentials as needed. + This class also supports mutual TLS via :meth:`configure_mtls_channel` + method. This method first tries to load client certificate and private key + using the given client_cert_callabck; if callback is None or fails, it tries + to load application default SSL credentials. Exceptions are raised if there + are problems with the certificate, private key, or the loading process, so + it should be called within a try/except block. + + First we create an :class:`AuthorizedSession` instance and specify the endpoints:: + + regular_endpoint = 'https://pubsub.googleapis.com/v1/projects/{my_project_id}/topics' + mtls_endpoint = 'https://pubsub.mtls.googleapis.com/v1/projects/{my_project_id}/topics' + + authed_session = AuthorizedSession(credentials) + + Now we can pass a callback to :meth:`configure_mtls_channel`:: + + def my_cert_callback(): + # some code to load client cert bytes and private key bytes, both in + # PEM format. + some_code_to_load_client_cert_and_key() + if loaded: + return True, cert, key + else: + return False, None, None + + # Always call configure_mtls_channel within a try/except block. + try: + authed_session.configure_mtls_channel(my_cert_callback) + except: + # handle exceptions. + + if authed_session.is_mtls: + response = authed_session.request('GET', mtls_endpoint) + else: + response = authed_session.request('GET', regular_endpoint) + + You can alternatively use application default SSL credentials like this:: + + try: + authed_session.configure_mtls_channel() + except: + # handle exceptions. + Args: credentials (google.auth.credentials.Credentials): The credentials to add to the request. @@ -229,6 +322,7 @@ def __init__( self._refresh_status_codes = refresh_status_codes self._max_refresh_attempts = max_refresh_attempts self._refresh_timeout = refresh_timeout + self._is_mtls = False if auth_request is None: auth_request_session = requests.Session() @@ -247,6 +341,40 @@ def __init__( # credentials.refresh). self._auth_request = auth_request + def configure_mtls_channel(self, client_cert_callback=None): + """Configure the client certificate and key for SSL connection. + + If client certificate and key are successfully obtained (from the given + client_cert_callabck or from application default SSL credentials), a + :class:`_MutualTlsAdapter` instance will be mounted to "https://" prefix. + + Args: + client_cert_callabck (Optional[Callable[[], (bool, bytes, bytes)]]): + The optional callback returns a boolean indicating if the call + is successful, and the client certificate and private key bytes + both in PEM format. + If the call is not succesful, application default SSL credentials + will be used. + + Raises: + ImportError: If certifi or pyOpenSSL is not installed. + OpenSSL.crypto.Error: If client cert or key is invalid. + OSError: If the cert provider command launch fails during the + application default SSL credentials loading process. + RuntimeError: If the cert provider command has a runtime error during + the application default SSL credentials loading process. + ValueError: If the context aware metadata file is malformed or the + cert provider command doesn't produce both client certicate and + key during the application default SSL credentials loading process. + """ + self._is_mtls, cert, key = google.auth.transport._mtls_helper.get_client_cert_and_key( + client_cert_callback + ) + + if self._is_mtls: + mtls_adapter = _MutualTlsAdapter(cert, key) + self.mount("https://", mtls_adapter) + def request( self, method, @@ -361,3 +489,8 @@ def request( ) return response + + @property + def is_mtls(self): + """Indicates if the created SSL channel is mutual TLS.""" + return self._is_mtls diff --git a/google/auth/transport/urllib3.py b/google/auth/transport/urllib3.py index d1905e94e..cc21e773f 100644 --- a/google/auth/transport/urllib3.py +++ b/google/auth/transport/urllib3.py @@ -17,7 +17,7 @@ from __future__ import absolute_import import logging - +import warnings # Certifi is Mozilla's certificate bundle. Urllib3 needs a certificate bundle # to verify HTTPS requests, and certifi is the recommended and most reliable @@ -149,6 +149,39 @@ def _make_default_http(): return urllib3.PoolManager() +def _make_mutual_tls_http(cert, key): + """Create a mutual TLS HTTP connection with the given client cert and key. + See https://github.com/urllib3/urllib3/issues/474#issuecomment-253168415 + + Args: + cert (bytes): client certificate in PEM format + key (bytes): client private key in PEM format + + Returns: + urllib3.PoolManager: Mutual TLS HTTP connection. + + Raises: + ImportError: If certifi or pyOpenSSL is not installed. + OpenSSL.crypto.Error: If the cert or key is invalid. + """ + import certifi + from OpenSSL import crypto + import urllib3.contrib.pyopenssl + + urllib3.contrib.pyopenssl.inject_into_urllib3() + ctx = urllib3.util.ssl_.create_urllib3_context() + ctx.load_verify_locations(cafile=certifi.where()) + + pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) + x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + + ctx._ctx.use_certificate(x509) + ctx._ctx.use_privatekey(pkey) + + http = urllib3.PoolManager(ssl_context=ctx) + return http + + class AuthorizedHttp(urllib3.request.RequestMethods): """A urllib3 HTTP class with credentials. @@ -168,6 +201,49 @@ class AuthorizedHttp(urllib3.request.RequestMethods): The underlying :meth:`urlopen` implementation handles adding the credentials' headers to the request and refreshing credentials as needed. + This class also supports mutual TLS via :meth:`configure_mtls_channel` + method. This method first tries to load client certificate and private key + using the given client_cert_callabck; if callback is None or fails, it tries + to load application default SSL credentials. Exceptions are raised if there + are problems with the certificate, private key, or the loading process, so + it should be called within a try/except block. + + First we create an :class:`AuthorizedHttp` instance and specify the endpoints:: + + regular_endpoint = 'https://pubsub.googleapis.com/v1/projects/{my_project_id}/topics' + mtls_endpoint = 'https://pubsub.mtls.googleapis.com/v1/projects/{my_project_id}/topics' + + authed_http = AuthorizedHttp(credentials) + + Now we can pass a callback to :meth:`configure_mtls_channel`:: + + def my_cert_callback(): + # some code to load client cert bytes and private key bytes, both in + # PEM format. + some_code_to_load_client_cert_and_key() + if loaded: + return True, cert, key + else: + return False, None, None + + # Always call configure_mtls_channel within a try/except block. + try: + is_mtls = authed_http.configure_mtls_channel(my_cert_callback) + except: + # handle exceptions. + + if is_mtls: + response = authed_http.request('GET', mtls_endpoint) + else: + response = authed_http.request('GET', regular_endpoint) + + You can alternatively use application default SSL credentials like this:: + + try: + is_mtls = authed_http.configure_mtls_channel() + except: + # handle exceptions. + Args: credentials (google.auth.credentials.Credentials): The credentials to add to the request. @@ -189,12 +265,14 @@ def __init__( refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, ): - if http is None: - http = _make_default_http() + self.http = _make_default_http() + self._has_user_provided_http = False + else: + self.http = http + self._has_user_provided_http = True self.credentials = credentials - self.http = http self._refresh_status_codes = refresh_status_codes self._max_refresh_attempts = max_refresh_attempts # Request instance used by internal methods (for example, @@ -203,6 +281,51 @@ def __init__( super(AuthorizedHttp, self).__init__() + def configure_mtls_channel(self, client_cert_callabck=None): + """Configures mutual TLS channel using the given client_cert_callabck or + application default SSL credentials. Returns True if the channel is + mutual TLS and False otherwise. Note that the `http` provided in the + constructor will be overwritten. + + Args: + client_cert_callabck (Optional[Callable[[], (bool, bytes, bytes)]]): + The optional callback returns a boolean indicating if the call + is successful, and the client certificate and private key bytes + both in PEM format. + If the call is not succesful, application default SSL credentials + will be used. + + Returns: + True if the channel is mutual TLS and False otherwise. + + Raises: + ImportError: If certifi or pyOpenSSL is not installed. + OpenSSL.crypto.Error: If client cert or key is invalid. + OSError: If the cert provider command launch fails during the + application default SSL credentials loading process. + RuntimeError: If the cert provider command has a runtime error during + the application default SSL credentials loading process. + ValueError: If the context aware metadata file is malformed or the + cert provider command doesn't produce both client certicate and + key during the application default SSL credentials loading process. + """ + found_cert_key, cert, key = transport._mtls_helper.get_client_cert_and_key( + client_cert_callabck + ) + + if found_cert_key: + self.http = _make_mutual_tls_http(cert, key) + else: + self.http = _make_default_http() + + if self._has_user_provided_http: + self._has_user_provided_http = False + warnings.warn( + "`http` provided in the constructor is overwritten", UserWarning + ) + + return found_cert_key + def urlopen(self, method, url, body=None, headers=None, **kwargs): """Implementation of urllib3's urlopen.""" # pylint: disable=arguments-differ diff --git a/system_tests/noxfile.py b/system_tests/noxfile.py index e37049e52..6280d0e76 100644 --- a/system_tests/noxfile.py +++ b/system_tests/noxfile.py @@ -170,7 +170,8 @@ def configure_cloud_sdk(session, application_default_credentials, project=False) # Test sesssions TEST_DEPENDENCIES = ["pytest", "requests"] -PYTHON_VERSIONS=['2.7', '3.7'] +PYTHON_VERSIONS = ["2.7", "3.7"] + @nox.session(python=PYTHON_VERSIONS) def service_account(session): @@ -297,3 +298,11 @@ def grpc(session): session.install(*TEST_DEPENDENCIES, "google-cloud-pubsub==1.0.0") session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE session.run("pytest", "test_grpc.py") + + +@nox.session(python=PYTHON_VERSIONS) +def mtls_http(session): + session.install(LIBRARY_DIR) + session.install(*TEST_DEPENDENCIES, "pyopenssl") + session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE + session.run("pytest", "test_mtls_http.py") diff --git a/system_tests/test_mtls_http.py b/system_tests/test_mtls_http.py new file mode 100644 index 000000000..e7ea0b242 --- /dev/null +++ b/system_tests/test_mtls_http.py @@ -0,0 +1,71 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from os import path + +import google.auth +import google.auth.credentials +import google.auth.transport.requests +import google.auth.transport.urllib3 + +MTLS_ENDPOINT = "https://pubsub.mtls.googleapis.com/v1/projects/{}/topics" +REGULAR_ENDPOINT = "https://pubsub.googleapis.com/v1/projects/{}/topics" + + +def check_context_aware_metadata(): + metadata_path = path.expanduser("~/.secureConnect/context_aware_metadata.json") + return path.exists(metadata_path) + + +def test_requests(): + credentials, project_id = google.auth.default() + credentials = google.auth.credentials.with_scopes_if_required( + credentials, ["https://www.googleapis.com/auth/pubsub"] + ) + + authed_session = google.auth.transport.requests.AuthorizedSession(credentials) + authed_session.configure_mtls_channel() + + # If the devices has context aware metadata, then a mutual TLS channel is + # supposed to be created. + assert authed_session.is_mtls == check_context_aware_metadata() + + if authed_session.is_mtls: + response = authed_session.get(MTLS_ENDPOINT.format(project_id)) + else: + response = authed_session.get(REGULAR_ENDPOINT.format(project_id)) + + assert response.ok + + +def test_urllib3(): + credentials, project_id = google.auth.default() + credentials = google.auth.credentials.with_scopes_if_required( + credentials, ["https://www.googleapis.com/auth/pubsub"] + ) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp(credentials) + is_mtls = authed_http.configure_mtls_channel() + + # If the devices has context aware metadata, then a mutual TLS channel is + # supposed to be created. + assert is_mtls == check_context_aware_metadata() + + if is_mtls: + response = authed_http.request("GET", MTLS_ENDPOINT.format(project_id)) + else: + response = authed_http.request("GET", REGULAR_ENDPOINT.format(project_id)) + + assert response.status == 200 diff --git a/tests/conftest.py b/tests/conftest.py index 7f9a968b7..cf8a0f9e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import sys import mock import pytest +def pytest_configure(): + """Load public certificate and private key.""" + pytest.data_dir = os.path.join(os.path.dirname(__file__), "data") + + with open(os.path.join(pytest.data_dir, "privatekey.pem"), "rb") as fh: + pytest.private_key_bytes = fh.read() + + with open(os.path.join(pytest.data_dir, "public_cert.pem"), "rb") as fh: + pytest.public_cert_bytes = fh.read() + + @pytest.fixture def mock_non_existent_module(monkeypatch): """Mocks a non-existing module in sys.modules. diff --git a/tests/data/context_aware_metadata.json b/tests/data/context_aware_metadata.json new file mode 100644 index 000000000..ec40e783f --- /dev/null +++ b/tests/data/context_aware_metadata.json @@ -0,0 +1,6 @@ +{ + "cert_provider_command":[ + "/opt/google/endpoint-verification/bin/SecureConnectHelper", + "--print_certificate"], + "device_resource_ids":["11111111-1111-1111"] +} diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py new file mode 100644 index 000000000..d14ac4744 --- /dev/null +++ b/tests/transport/test__mtls_helper.py @@ -0,0 +1,233 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re + +import mock +import pytest + +from google.auth.transport import _mtls_helper + +CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} + +CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {} + + +def check_cert_and_key(content, expected_cert, expected_key): + success = True + + cert_match = re.findall(_mtls_helper._CERT_REGEX, content) + success = success and len(cert_match) == 1 and cert_match[0] == expected_cert + + key_match = re.findall(_mtls_helper._KEY_REGEX, content) + success = success and len(key_match) == 1 and key_match[0] == expected_key + + return success + + +class TestCertAndKeyRegex(object): + def test_cert_and_key(self): + # Test single cert and single key + check_cert_and_key( + pytest.public_cert_bytes + pytest.private_key_bytes, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + check_cert_and_key( + pytest.private_key_bytes + pytest.public_cert_bytes, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + # Test cert chain and single key + check_cert_and_key( + pytest.public_cert_bytes + + pytest.public_cert_bytes + + pytest.private_key_bytes, + pytest.public_cert_bytes + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + check_cert_and_key( + pytest.private_key_bytes + + pytest.public_cert_bytes + + pytest.public_cert_bytes, + pytest.public_cert_bytes + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + def test_key(self): + # Create some fake keys for regex check. + KEY = b"""-----BEGIN PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END PRIVATE KEY-----""" + RSA_KEY = b"""-----BEGIN RSA PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END RSA PRIVATE KEY-----""" + EC_KEY = b"""-----BEGIN EC PRIVATE KEY----- + MIIBCgKCAQEA4ej0p7bQ7L/r4rVGUz9RN4VQWoej1Bg1mYWIDYslvKrk1gpj7wZg + /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB + -----END EC PRIVATE KEY-----""" + + check_cert_and_key( + pytest.public_cert_bytes + KEY, pytest.public_cert_bytes, KEY + ) + check_cert_and_key( + pytest.public_cert_bytes + RSA_KEY, pytest.public_cert_bytes, RSA_KEY + ) + check_cert_and_key( + pytest.public_cert_bytes + EC_KEY, pytest.public_cert_bytes, EC_KEY + ) + + +class TestCheckaMetadataPath(object): + def test_success(self): + metadata_path = os.path.join(pytest.data_dir, "context_aware_metadata.json") + returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) + assert returned_path is not None + + def test_failure(self): + metadata_path = os.path.join(pytest.data_dir, "not_exists.json") + returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) + assert returned_path is None + + +class TestReadMetadataFile(object): + def test_success(self): + metadata_path = os.path.join(pytest.data_dir, "context_aware_metadata.json") + metadata = _mtls_helper._read_dca_metadata_file(metadata_path) + + assert "cert_provider_command" in metadata + + def test_file_not_json(self): + # read a file which is not json format. + metadata_path = os.path.join(pytest.data_dir, "privatekey.pem") + with pytest.raises(ValueError): + _mtls_helper._read_dca_metadata_file(metadata_path) + + +class TestGetClientSslCredentials(object): + def create_mock_process(self, output, error): + # There are two steps to execute a script with subprocess.Popen. + # (1) process = subprocess.Popen([comannds]) + # (2) stdout, stderr = process.communicate() + # This function creates a mock process which can be returned by a mock + # subprocess.Popen. The mock process returns the given output and error + # when mock_process.communicate() is called. + mock_process = mock.Mock() + attrs = {"communicate.return_value": (output, error), "returncode": 0} + mock_process.configure_mock(**attrs) + return mock_process + + @mock.patch("subprocess.Popen", autospec=True) + def test_success(self, mock_popen): + mock_popen.return_value = self.create_mock_process( + pytest.public_cert_bytes + pytest.private_key_bytes, b"" + ) + cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes + + @mock.patch("subprocess.Popen", autospec=True) + def test_success_with_cert_chain(self, mock_popen): + PUBLIC_CERT_CHAIN_BYTES = pytest.public_cert_bytes + pytest.public_cert_bytes + mock_popen.return_value = self.create_mock_process( + PUBLIC_CERT_CHAIN_BYTES + pytest.private_key_bytes, b"" + ) + cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + assert cert == PUBLIC_CERT_CHAIN_BYTES + assert key == pytest.private_key_bytes + + def test_missing_cert_provider_command(self): + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials( + CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND + ) + + @mock.patch("subprocess.Popen", autospec=True) + def test_missing_cert(self, mock_popen): + mock_popen.return_value = self.create_mock_process( + pytest.private_key_bytes, b"" + ) + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + @mock.patch("subprocess.Popen", autospec=True) + def test_missing_key(self, mock_popen): + mock_popen.return_value = self.create_mock_process( + pytest.public_cert_bytes, b"" + ) + with pytest.raises(ValueError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + @mock.patch("subprocess.Popen", autospec=True) + def test_cert_provider_returns_error(self, mock_popen): + mock_popen.return_value = self.create_mock_process(b"", b"some error") + mock_popen.return_value.returncode = 1 + with pytest.raises(RuntimeError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + @mock.patch("subprocess.Popen", autospec=True) + def test_popen_raise_exception(self, mock_popen): + mock_popen.side_effect = OSError() + with pytest.raises(OSError): + assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + +class TestGetClientCertAndKey(object): + def test_callback_success(self): + callback = mock.Mock() + callback.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key(callback) + assert found_cert_key + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes + + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True + ) + def test_no_metadata(self, mock_check_dca_metadata_path): + mock_check_dca_metadata_path.return_value = None + + found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key() + assert not found_cert_key + + @mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True + ) + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True + ) + def test_use_metadata( + self, mock_check_dca_metadata_path, mock_get_client_ssl_credentials + ): + mock_check_dca_metadata_path.return_value = os.path.join( + pytest.data_dir, "context_aware_metadata.json" + ) + mock_get_client_ssl_credentials.return_value = ( + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key() + assert found_cert_key + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index 9aafd88b1..46d9a7bc8 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -17,12 +17,14 @@ import freezegun import mock +import OpenSSL import pytest import requests import requests.adapters from six.moves import http_client import google.auth.credentials +import google.auth.transport._mtls_helper import google.auth.transport.requests from tests.transport import compliance @@ -150,6 +152,34 @@ def send(self, request, **kwargs): return super(TimeTickAdapterStub, self).send(request, **kwargs) +class TestMutualTlsAdapter(object): + @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager") + @mock.patch.object(requests.adapters.HTTPAdapter, "proxy_manager_for") + def test_success(self, mock_proxy_manager_for, mock_init_poolmanager): + adapter = google.auth.transport.requests._MutualTlsAdapter( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + + adapter.init_poolmanager() + mock_init_poolmanager.assert_called_with(ssl_context=adapter._ctx_poolmanager) + + adapter.proxy_manager_for() + mock_proxy_manager_for.assert_called_with(ssl_context=adapter._ctx_proxymanager) + + def test_invalid_cert_or_key(self): + with pytest.raises(OpenSSL.crypto.Error): + google.auth.transport.requests._MutualTlsAdapter( + b"invalid cert", b"invalid key" + ) + + @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None}) + def test_import_error(self): + with pytest.raises(ImportError): + google.auth.transport.requests._MutualTlsAdapter( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + + def make_response(status=http_client.OK, data=None): response = requests.Response() response.status_code = status @@ -157,7 +187,7 @@ def make_response(status=http_client.OK, data=None): return response -class TestAuthorizedHttp(object): +class TestAuthorizedSession(object): TEST_URL = "http://example.com/" def test_constructor(self): @@ -326,3 +356,62 @@ def test_request_timeout_w_refresh_timeout_timeout_error(self, frozen_time): authed_session.request( "GET", self.TEST_URL, timeout=60, max_allowed_time=2.9 ) + + def test_configure_mtls_channel_with_callback(self): + mock_callback = mock.Mock() + mock_callback.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + auth_session = google.auth.transport.requests.AuthorizedSession( + credentials=mock.Mock() + ) + auth_session.configure_mtls_channel(mock_callback) + + assert auth_session.is_mtls + assert isinstance( + auth_session.adapters["https://"], + google.auth.transport.requests._MutualTlsAdapter, + ) + + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_with_metadata(self, mock_get_client_cert_and_key): + mock_get_client_cert_and_key.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + auth_session = google.auth.transport.requests.AuthorizedSession( + credentials=mock.Mock() + ) + auth_session.configure_mtls_channel() + + assert auth_session.is_mtls + assert isinstance( + auth_session.adapters["https://"], + google.auth.transport.requests._MutualTlsAdapter, + ) + + @mock.patch.object(google.auth.transport.requests._MutualTlsAdapter, "__init__") + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_non_mtls( + self, mock_get_client_cert_and_key, mock_adapter_ctor + ): + mock_get_client_cert_and_key.return_value = (False, None, None) + + auth_session = google.auth.transport.requests.AuthorizedSession( + credentials=mock.Mock() + ) + auth_session.configure_mtls_channel() + + assert not auth_session.is_mtls + + # Assert _MutualTlsAdapter constructor is not called. + mock_adapter_ctor.assert_not_called() diff --git a/tests/transport/test_urllib3.py b/tests/transport/test_urllib3.py index 8a307332a..67833c3a8 100644 --- a/tests/transport/test_urllib3.py +++ b/tests/transport/test_urllib3.py @@ -13,10 +13,13 @@ # limitations under the License. import mock +import OpenSSL +import pytest from six.moves import http_client import urllib3 import google.auth.credentials +import google.auth.transport._mtls_helper import google.auth.transport.urllib3 from tests.transport import compliance @@ -77,6 +80,27 @@ def __init__(self, status=http_client.OK, data=None): self.data = data +class TestMakeMutualTlsHttp(object): + def test_success(self): + http = google.auth.transport.urllib3._make_mutual_tls_http( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + assert isinstance(http, urllib3.PoolManager) + + def test_crypto_error(self): + with pytest.raises(OpenSSL.crypto.Error): + google.auth.transport.urllib3._make_mutual_tls_http( + b"invalid cert", b"invalid key" + ) + + @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None}) + def test_import_error(self): + with pytest.raises(ImportError): + google.auth.transport.urllib3._make_mutual_tls_http( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + + class TestAuthorizedHttp(object): TEST_URL = "http://example.com" @@ -138,3 +162,65 @@ def test_proxies(self): authed_http.headers = mock.sentinel.headers assert authed_http.headers == http.headers + + @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True) + def test_configure_mtls_channel_with_callback(self, mock_make_mutual_tls_http): + callback = mock.Mock() + callback.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials=mock.Mock(), http=mock.Mock() + ) + + with pytest.warns(UserWarning): + is_mtls = authed_http.configure_mtls_channel(callback) + + assert is_mtls + mock_make_mutual_tls_http.assert_called_once_with( + cert=pytest.public_cert_bytes, key=pytest.private_key_bytes + ) + + @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True) + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_with_metadata( + self, mock_get_client_cert_and_key, mock_make_mutual_tls_http + ): + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials=mock.Mock() + ) + + mock_get_client_cert_and_key.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + is_mtls = authed_http.configure_mtls_channel() + + assert is_mtls + mock_get_client_cert_and_key.assert_called_once() + mock_make_mutual_tls_http.assert_called_once_with( + cert=pytest.public_cert_bytes, key=pytest.private_key_bytes + ) + + @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True) + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_non_mtls( + self, mock_get_client_cert_and_key, mock_make_mutual_tls_http + ): + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials=mock.Mock() + ) + + mock_get_client_cert_and_key.return_value = (False, None, None) + is_mtls = authed_http.configure_mtls_channel() + + assert not is_mtls + mock_get_client_cert_and_key.assert_called_once() + mock_make_mutual_tls_http.assert_not_called()