From 4732dae50118aa94eee09a1b7caa5e98371c6f6d Mon Sep 17 00:00:00 2001 From: Juerg Wullschleger Date: Wed, 29 Mar 2023 00:44:34 -0700 Subject: [PATCH] Let AWS KMS integration in Python use boto3 instead of wrapping the C++ integration. This should not change the behavior of the current API, it implements the same as cc/integration/awskms/aws_kms_client.cc in Python. PiperOrigin-RevId: 520254665 Change-Id: Iddcf4c5d9c84694c24708b32cb58496c76716a1f --- requirements.in | 1 + requirements.txt | 36 ++++- tink/integration/awskms/BUILD.bazel | 4 +- tink/integration/awskms/__init__.py | 2 +- tink/integration/awskms/_aws_kms_client.py | 127 +++++++++++++++--- .../awskms/_aws_kms_client_test.py | 57 ++++++++ 6 files changed, 208 insertions(+), 19 deletions(-) diff --git a/requirements.in b/requirements.in index 1c826af..e7aa751 100644 --- a/requirements.in +++ b/requirements.in @@ -1,2 +1,3 @@ absl-py==1.3.0 protobuf==4.21.9 +boto3==1.26.89 diff --git a/requirements.txt b/requirements.txt index 1c3b47b..98a1936 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # -# This file is autogenerated by pip-compile with python 3.10 -# To update, run: +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: # # pip-compile --generate-hashes --output-file=requirements.txt requirements.in # @@ -8,6 +8,22 @@ absl-py==1.3.0 \ --hash=sha256:34995df9bd7a09b3b8749e230408f5a2a2dd7a68a0d33c12a3d0cb15a041a507 \ --hash=sha256:463c38a08d2e4cef6c498b76ba5bd4858e4c6ef51da1a5a1f27139a022e20248 # via -r requirements.in +boto3==1.26.89 \ + --hash=sha256:09929b24aaec4951e435d53d31f800e2ca52244af049dc11e5385ce062e106e9 \ + --hash=sha256:e819812f16fab46fadf9b2853a46aaa126e108e7f038502dde555ebbbfc80133 + # via -r requirements.in +botocore==1.29.89 \ + --hash=sha256:ac8da651f73a9d5759cf5d80beba68deda407e56aaaeb10d249fd557459f3b56 \ + --hash=sha256:b757e59feca82ac62934f658918133116b4535cf66f1d72ff4935fa24e522527 + # via + # boto3 + # s3transfer 
+jmespath==1.0.1 \ + --hash=sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980 \ + --hash=sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe + # via + # boto3 + # botocore protobuf==4.21.9 \ --hash=sha256:2c9c2ed7466ad565f18668aa4731c535511c5d9a40c6da39524bccf43e441719 \ --hash=sha256:48e2cd6b88c6ed3d5877a3ea40df79d08374088e89bedc32557348848dff250b \ @@ -24,3 +40,19 @@ protobuf==4.21.9 \ --hash=sha256:e575c57dc8b5b2b2caa436c16d44ef6981f2235eb7179bfc847557886376d740 \ --hash=sha256:f9eae277dd240ae19bb06ff4e2346e771252b0e619421965504bd1b1bba7c5fa # via -r requirements.in +python-dateutil==2.8.2 \ + --hash=sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86 \ + --hash=sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9 + # via botocore +s3transfer==0.6.0 \ + --hash=sha256:06176b74f3a15f61f1b4f25a1fc29a4429040b7647133a463da8fa5bd28d5ecd \ + --hash=sha256:2ed07d3866f523cc561bf4a00fc5535827981b117dd7876f036b0c1aca42c947 + # via boto3 +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 + # via python-dateutil +urllib3==1.26.15 \ + --hash=sha256:8a388717b9476f934a21484e8c8e61875ab60644d29b9b39e11e4b9dc1c6b305 \ + --hash=sha256:aa751d169e23c7479ce47a0cb0da579e3ede798f994f5816a74e4f4500dcea42 + # via botocore diff --git a/tink/integration/awskms/BUILD.bazel b/tink/integration/awskms/BUILD.bazel index 6513157..4551368 100644 --- a/tink/integration/awskms/BUILD.bazel +++ b/tink/integration/awskms/BUILD.bazel @@ -18,10 +18,11 @@ py_library( srcs = ["_aws_kms_client.py"], srcs_version = "PY3", deps = [ + "//tink:tink_python", "//tink/aead", "//tink/aead:_kms_aead_key_manager", - "//tink/cc/pybind:tink_bindings", "//tink/core", + requirement("boto3"), ], ) @@ -35,6 +36,7 @@ py_test( srcs_version = "PY3", deps = [ ":awskms", + ":_aws_kms_client", "//tink:tink_python", 
"//tink/testing:helper", requirement("absl-py"), diff --git a/tink/integration/awskms/__init__.py b/tink/integration/awskms/__init__.py index a678c35..c86dc9c 100644 --- a/tink/integration/awskms/__init__.py +++ b/tink/integration/awskms/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""StreamingAead package.""" +"""AWS KMS integration package.""" from tink.integration.awskms import _aws_kms_client diff --git a/tink/integration/awskms/_aws_kms_client.py b/tink/integration/awskms/_aws_kms_client.py index 86ebc43..822e1b4 100644 --- a/tink/integration/awskms/_aws_kms_client.py +++ b/tink/integration/awskms/_aws_kms_client.py @@ -13,13 +13,95 @@ # limitations under the License. """A client for AWS KMS.""" +import binascii +import configparser import re +from typing import Tuple, Any, Dict +import boto3 +from botocore import exceptions + +import tink from tink import aead -from tink import core from tink.aead import _kms_aead_key_manager -from tink.cc.pybind import tink_bindings + + +AWS_KEYURI_PREFIX = 'aws-kms://' + + +def _encryption_context(associated_data: bytes) -> Dict[str, str]: + if associated_data: + hex_associated_data = binascii.hexlify(associated_data).decode('utf-8') + return {'associatedData': hex_associated_data} + else: + return dict() + + +class _AwsKmsAead(aead.Aead): + """Implements the Aead interface for AWS KMS.""" + + def __init__(self, client: Any, key_arn: str) -> None: + self.client = client + self.key_arn = key_arn + + def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: + try: + response = self.client.encrypt( + KeyId=self.key_arn, + Plaintext=plaintext, + EncryptionContext=_encryption_context(associated_data), + ) + return response['CiphertextBlob'] + except exceptions.ClientError as e: + raise tink.TinkError(e) + + def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes: + try: + response = self.client.decrypt( + 
KeyId=self.key_arn, + CiphertextBlob=ciphertext, + EncryptionContext=_encryption_context(associated_data), + ) + if response['KeyId'] != self.key_arn: + raise tink.TinkError( + 'invalid key id: got %s, want %s' + % (response['KeyId'], self.key_arn) + ) + return response['Plaintext'] + except exceptions.ClientError as e: + raise tink.TinkError(e) + + +def _key_uri_to_key_arn(key_uri: str) -> str: + if not key_uri.startswith(AWS_KEYURI_PREFIX): + raise tink.TinkError('invalid key URI') + return key_uri[len(AWS_KEYURI_PREFIX) :] + + +def _parse_config(config_path: str) -> Tuple[str, str]: + """Returns ('aws_access_key_id', 'aws_secret_access_key') from a config.""" + config = configparser.ConfigParser() + config.read(config_path) + if 'default' not in config: + raise ValueError('invalid config: default not found') + default = config['default'] + if 'aws_access_key_id' not in default: + raise ValueError('invalid config: aws_access_key_id not found') + aws_access_key_id = default['aws_access_key_id'] + if 'aws_secret_access_key' not in default: + raise ValueError('invalid config: aws_secret_access_key not found') + aws_secret_access_key = default['aws_secret_access_key'] + return (aws_access_key_id, aws_secret_access_key) + + +def _get_region_from_key_arn(key_arn: str) -> str: + # An AWS key ARN is of the form + # arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab. + key_arn_parts = key_arn.split(':') + if len(key_arn_parts) < 6: + raise tink.TinkError('invalid key id') + return key_arn_parts[3] class AwsKmsClient(_kms_aead_key_manager.KmsClient): @@ -44,19 +126,19 @@ def __init__(self, key_uri: str, credentials_path: str): ValueError: If the path or filename of the credentials is invalid. TinkError: If the key uri is not valid. 
""" - - match = re.match('aws-kms://arn:aws:kms:([a-z0-9-]+):', key_uri) if not key_uri: - self._key_uri = '' - elif match: - self._key_uri = key_uri + self._key_arn = None else: - raise core.TinkError - - self.cc_client = tink_bindings.AwsKmsClient(key_uri, credentials_path) + match = re.match('aws-kms://arn:aws:kms:([a-z0-9-]+):', key_uri) + if not match: + raise tink.TinkError('invalid key URI') + self._key_arn = _key_uri_to_key_arn(key_uri) + aws_access_key_id, aws_secret_access_key = _parse_config(credentials_path) + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key def does_support(self, key_uri: str) -> bool: - """Returns true iff this client supports KMS key specified in 'key_uri'. + """Returns true if this client supports KMS key specified in 'key_uri'. Args: key_uri: Text, URI of the key to be checked. @@ -64,9 +146,12 @@ def does_support(self, key_uri: str) -> bool: Returns: A boolean value which is true if the key is supported and false otherwise. """ - return self.cc_client.does_support(key_uri) + if not key_uri.startswith(AWS_KEYURI_PREFIX): + return False + if not self._key_arn: + return True + return _key_uri_to_key_arn(key_uri) == self._key_arn - @core.use_tink_errors def get_aead(self, key_uri: str) -> aead.Aead: """Returns an Aead-primitive backed by KMS key specified by 'key_uri'. @@ -79,8 +164,20 @@ def get_aead(self, key_uri: str) -> aead.Aead: Raises: TinkError: If the key_uri is not supported. 
""" - - return aead.AeadCcToPyWrapper(self.cc_client.get_aead(key_uri)) + if not self.does_support(key_uri): + if self._key_arn: + raise tink.TinkError( + 'This client is bound to %s and cannot use key %s' % + (self._key_arn, key_uri)) + raise tink.TinkError( + 'This client does not support key %s' % key_uri) + key_arn = _key_uri_to_key_arn(key_uri) + session = boto3.session.Session( + aws_access_key_id=self._aws_access_key_id, + aws_secret_access_key=self._aws_secret_access_key, + region_name=_get_region_from_key_arn(key_arn), + ) + return _AwsKmsAead(session.client('kms'), key_arn) @classmethod def register_client(cls, key_uri, credentials_path) -> None: diff --git a/tink/integration/awskms/_aws_kms_client_test.py b/tink/integration/awskms/_aws_kms_client_test.py index 5a0901b..129f2de 100644 --- a/tink/integration/awskms/_aws_kms_client_test.py +++ b/tink/integration/awskms/_aws_kms_client_test.py @@ -15,12 +15,16 @@ import os +import tempfile + from absl.testing import absltest import tink from tink.integration import awskms +from tink.integration.awskms import _aws_kms_client from tink.testing import helper + CREDENTIAL_PATH = os.path.join(helper.tink_py_testdata_path(), 'aws/credentials.ini') KEY_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/' @@ -63,6 +67,59 @@ def test_wrong_credentials_path(self): with self.assertRaises(ValueError): awskms.AwsKmsClient(KEY_URI, '../credentials.txt') + def test_parse_valid_credentials_works(self): + config_file = tempfile.NamedTemporaryFile(delete=False) + with open(config_file.name, 'w') as f: + f.write(""" +[otherSection] +aws_access_key_id = other_key_id +aws_secret_access_key = other_key + +[default] +aws_access_key_id = key_id_123 +aws_secret_access_key = key_123""") + + aws_access_key_id, aws_secret_access_key = _aws_kms_client._parse_config( + config_file.name + ) + self.assertEqual(aws_access_key_id, 'key_id_123') + self.assertEqual(aws_secret_access_key, 'key_123') + + os.unlink(config_file.name) + + def 
test_parse_credentials_without_key_id_fails(self): + config_file = tempfile.NamedTemporaryFile(delete=False) + with open(config_file.name, 'w') as f: + f.write(""" +[default] +aws_secret_access_key = key_123""") + with self.assertRaises(ValueError): + _aws_kms_client._parse_config(config_file.name) + + os.unlink(config_file.name) + + def test_parse_credentials_without_key_fails(self): + config_file = tempfile.NamedTemporaryFile(delete=False) + with open(config_file.name, 'w') as f: + f.write(""" +[default] +aws_access_key_id = key_id_123""") + with self.assertRaises(ValueError): + _aws_kms_client._parse_config(config_file.name) + + os.unlink(config_file.name) + + def test_parse_credentials_without_default_section_fails(self): + config_file = tempfile.NamedTemporaryFile(delete=False) + with open(config_file.name, 'w') as f: + f.write(""" +[otherSection] +aws_access_key_id = other_key_id +aws_secret_access_key = other_key""") + with self.assertRaises(ValueError): + _aws_kms_client._parse_config(config_file.name) + + os.unlink(config_file.name) if __name__ == '__main__': absltest.main()