Skip to content

Commit

Permalink
feat: add mTLS ADC support for HTTP
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Mar 10, 2020
1 parent b2dd77f commit 1dee698
Show file tree
Hide file tree
Showing 11 changed files with 922 additions and 6 deletions.
153 changes: 153 additions & 0 deletions google/auth/transport/_mtls_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for getting mTLS cert and key, for internal use only."""

import json
import logging
from os import path
import re
import subprocess

CONTEXT_AWARE_METADATA_PATH = "~/.secureConnect/context_aware_metadata.json"
_CERT_PROVIDER_COMMAND = "cert_provider_command"
_CERT_REGEX = re.compile(
b"-----BEGIN CERTIFICATE-----.+-----END CERTIFICATE-----\r?\n?", re.DOTALL
)

# support various format of key files, e.g.
# "-----BEGIN PRIVATE KEY-----...",
# "-----BEGIN EC PRIVATE KEY-----...",
# "-----BEGIN RSA PRIVATE KEY-----..."
_KEY_REGEX = re.compile(
b"-----BEGIN [A-Z ]*PRIVATE KEY-----.+-----END [A-Z ]*PRIVATE KEY-----\r?\n?",
re.DOTALL,
)

_LOGGER = logging.getLogger(__name__)


def _check_dca_metadata_path(metadata_path):
"""Checks for context aware metadata. If it exists, returns the absolute path;
otherwise returns None.
Args:
metadata_path (str): context aware metadata path.
Returns:
str: absolute path if exists and None otherwise.
"""
metadata_path = path.expanduser(metadata_path)
if not path.exists(metadata_path):
_LOGGER.debug("%s is not found, skip client SSL authentication.", metadata_path)
return None
return metadata_path


def _read_dca_metadata_file(metadata_path):
"""Loads context aware metadata from the given path.
Args:
metadata_path (str): context aware metadata path.
Returns:
Dict[str, str]: The metadata.
Raises:
ValueError: If failed to parse metadata as JSON.
"""
with open(metadata_path) as f:
metadata = json.load(f)

return metadata


def get_client_ssl_credentials(metadata_json):
"""Returns the client side mTLS cert and key.
Args:
metadata_json (Dict[str, str]): metadata JSON file which contains the cert
provider command.
Returns:
Tuple[bytes, bytes]: client certificate and key, both in PEM format.
Raises:
OSError: If the cert provider command failed to run.
RuntimeError: If the cert provider command has a runtime error.
ValueError: If the metadata json file doesn't contain the cert provider
command or if the command doesn't produce both the client certificate
and client key.
"""
# TODO: implement an in-memory cache of cert and key so we don't have to
# run cert provider command every time.

# Check the cert provider command existence in the metadata json file.
if _CERT_PROVIDER_COMMAND not in metadata_json:
raise ValueError("Cert provider command is not found")

# Execute the command. It throws OsError in case of system failure.
command = metadata_json[_CERT_PROVIDER_COMMAND]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()

# Check cert provider command execution error.
if process.returncode != 0:
raise RuntimeError(
"Cert provider command returns non-zero status code %s" % process.returncode
)

# Extract certificate (chain) and key.
cert_match = re.findall(_CERT_REGEX, stdout)
if len(cert_match) != 1:
raise ValueError("Client SSL certificate is missing or invalid")
key_match = re.findall(_KEY_REGEX, stdout)
if len(key_match) != 1:
raise ValueError("Client SSL key is missing or invalid")
return cert_match[0], key_match[0]


def get_client_cert_and_key(client_cert_callback=None):
"""Returns the client side certificate and private key. The function first
tries to get certificate and key from client_cert_callback; if the callback
is None or doesn't provide certificate and key, the function tries application
default SSL credentials.
Args:
client_cert_callback (Optional[Callable[[], (bool, bytes, bytes)]]): A
callback which returns a bool indicating if the call is successful,
and client certificate bytes and private key bytes both in PEM format.
Returns:
Tuple[bool, bytes, bytes]:
A boolean indicating if cert and key are obtained, the cert bytes
and key bytes both in PEM format.
Raises:
OSError: If the cert provider command failed to run.
RuntimeError: If the cert provider command has a runtime error.
ValueError: If the metadata json file doesn't contain the cert provider
command or if the command doesn't produce both the client certificate
and client key.
"""
if client_cert_callback:
return client_cert_callback()

metadata_path = _check_dca_metadata_path(CONTEXT_AWARE_METADATA_PATH)
if metadata_path:
metadata = _read_dca_metadata_file(metadata_path)
cert, key = get_client_ssl_credentials(metadata)
return True, cert, key

return False, None, None
133 changes: 133 additions & 0 deletions google/auth/transport/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@
)
import requests.adapters # pylint: disable=ungrouped-imports
import requests.exceptions # pylint: disable=ungrouped-imports
from requests.packages.urllib3.util.ssl_ import (
create_urllib3_context,
) # pylint: disable=ungrouped-imports
import six # pylint: disable=ungrouped-imports

from google.auth import exceptions
from google.auth import transport
import google.auth.transport._mtls_helper

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -182,6 +186,52 @@ def __call__(
six.raise_from(new_exc, caught_exc)


class _MutualTlsAdapter(requests.adapters.HTTPAdapter):
"""
A TransportAdapter that enables mutual TLS.
Args:
cert (bytes): client certificate in PEM format
key (bytes): client private key in PEM format
Raises:
ImportError: if certifi or pyOpenSSL is not installed
OpenSSL.crypto.Error: if client cert or key is invalid
"""

def __init__(self, cert, key):
import certifi
from OpenSSL import crypto
import urllib3.contrib.pyopenssl

urllib3.contrib.pyopenssl.inject_into_urllib3()

pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key)
x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert)

ctx_poolmanager = create_urllib3_context()
ctx_poolmanager.load_verify_locations(cafile=certifi.where())
ctx_poolmanager._ctx.use_certificate(x509)
ctx_poolmanager._ctx.use_privatekey(pkey)
self._ctx_poolmanager = ctx_poolmanager

ctx_proxymanager = create_urllib3_context()
ctx_proxymanager.load_verify_locations(cafile=certifi.where())
ctx_proxymanager._ctx.use_certificate(x509)
ctx_proxymanager._ctx.use_privatekey(pkey)
self._ctx_proxymanager = ctx_proxymanager

super(_MutualTlsAdapter, self).__init__()

def init_poolmanager(self, *args, **kwargs):
kwargs["ssl_context"] = self._ctx_poolmanager
super(_MutualTlsAdapter, self).init_poolmanager(*args, **kwargs)

def proxy_manager_for(self, *args, **kwargs):
kwargs["ssl_context"] = self._ctx_proxymanager
return super(_MutualTlsAdapter, self).proxy_manager_for(*args, **kwargs)


class AuthorizedSession(requests.Session):
"""A Requests Session class with credentials.
Expand All @@ -198,6 +248,49 @@ class AuthorizedSession(requests.Session):
The underlying :meth:`request` implementation handles adding the
credentials' headers to the request and refreshing credentials as needed.
This class also supports mutual TLS via :meth:`configure_mtls_channel`
method. This method first tries to load client certificate and private key
using the given client_cert_callabck; if callback is None or fails, it tries
to load application default SSL credentials. Exceptions are raised if there
are problems with the certificate, private key, or the loading process, so
it should be called within a try/except block.
First we create an :class:`AuthorizedSession` instance and specify the endpoints::
regular_endpoint = 'https://pubsub.googleapis.com/v1/projects/{my_project_id}/topics'
mtls_endpoint = 'https://pubsub.mtls.googleapis.com/v1/projects/{my_project_id}/topics'
authed_session = AuthorizedSession(credentials)
Now we can pass a callback to :meth:`configure_mtls_channel`::
def my_cert_callback():
# some code to load client cert bytes and private key bytes, both in
# PEM format.
some_code_to_load_client_cert_and_key()
if loaded:
return True, cert, key
else:
return False, None, None
# Always call configure_mtls_channel within a try/except block.
try:
authed_session.configure_mtls_channel(my_cert_callback)
except:
# handle exceptions.
if authed_session.is_mtls:
response = authed_session.request('GET', mtls_endpoint)
else:
response = authed_session.request('GET', regular_endpoint)
You can alternatively use application default SSL credentials like this::
try:
authed_session.configure_mtls_channel()
except:
# handle exceptions.
Args:
credentials (google.auth.credentials.Credentials): The credentials to
add to the request.
Expand Down Expand Up @@ -229,6 +322,7 @@ def __init__(
self._refresh_status_codes = refresh_status_codes
self._max_refresh_attempts = max_refresh_attempts
self._refresh_timeout = refresh_timeout
self._is_mtls = False

if auth_request is None:
auth_request_session = requests.Session()
Expand All @@ -247,6 +341,40 @@ def __init__(
# credentials.refresh).
self._auth_request = auth_request

def configure_mtls_channel(self, client_cert_callback=None):
"""Configure the client certificate and key for SSL connection.
If client certificate and key are successfully obtained (from the given
client_cert_callabck or from application default SSL credentials), a
:class:`_MutualTlsAdapter` instance will be mounted to "https://" prefix.
Args:
client_cert_callabck (Optional[Callable[[], (bool, bytes, bytes)]]):
The optional callback returns a boolean indicating if the call
is successful, and the client certificate and private key bytes
both in PEM format.
If the call is not succesful, application default SSL credentials
will be used.
Raises:
ImportError: If certifi or pyOpenSSL is not installed.
OpenSSL.crypto.Error: If client cert or key is invalid.
OSError: If the cert provider command launch fails during the
application default SSL credentials loading process.
RuntimeError: If the cert provider command has a runtime error during
the application default SSL credentials loading process.
ValueError: If the context aware metadata file is malformed or the
cert provider command doesn't produce both client certicate and
key during the application default SSL credentials loading process.
"""
self._is_mtls, cert, key = google.auth.transport._mtls_helper.get_client_cert_and_key(
client_cert_callback
)

if self._is_mtls:
mtls_adapter = _MutualTlsAdapter(cert, key)
self.mount("https://", mtls_adapter)

def request(
self,
method,
Expand Down Expand Up @@ -361,3 +489,8 @@ def request(
)

return response

@property
def is_mtls(self):
"""Indicates if the created SSL channel is mutual TLS."""
return self._is_mtls
Loading

0 comments on commit 1dee698

Please sign in to comment.