Skip to content

Commit

Permalink
Let AWS KMS integration in Python use boto3 instead of wrapping the C…
Browse files Browse the repository at this point in the history
…++ integration.

This should not change the behavior of the current API, it implements the same as
cc/integration/awskms/aws_kms_client.cc
in Python.

PiperOrigin-RevId: 520254665
Change-Id: Iddcf4c5d9c84694c24708b32cb58496c76716a1f
  • Loading branch information
juergw authored and copybara-github committed Mar 29, 2023
1 parent 2394c18 commit 4732dae
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 19 deletions.
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
absl-py==1.3.0
protobuf==4.21.9
boto3==1.26.89
36 changes: 34 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
#
# This file is autogenerated by pip-compile with python 3.10
# To update, run:
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile --generate-hashes --output-file=requirements.txt requirements.in
#
absl-py==1.3.0 \
--hash=sha256:34995df9bd7a09b3b8749e230408f5a2a2dd7a68a0d33c12a3d0cb15a041a507 \
--hash=sha256:463c38a08d2e4cef6c498b76ba5bd4858e4c6ef51da1a5a1f27139a022e20248
# via -r requirements.in
boto3==1.26.89 \
--hash=sha256:09929b24aaec4951e435d53d31f800e2ca52244af049dc11e5385ce062e106e9 \
--hash=sha256:e819812f16fab46fadf9b2853a46aaa126e108e7f038502dde555ebbbfc80133
# via -r requirements.in
botocore==1.29.89 \
--hash=sha256:ac8da651f73a9d5759cf5d80beba68deda407e56aaaeb10d249fd557459f3b56 \
--hash=sha256:b757e59feca82ac62934f658918133116b4535cf66f1d72ff4935fa24e522527
# via
# boto3
# s3transfer
jmespath==1.0.1 \
--hash=sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980 \
--hash=sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe
# via
# boto3
# botocore
protobuf==4.21.9 \
--hash=sha256:2c9c2ed7466ad565f18668aa4731c535511c5d9a40c6da39524bccf43e441719 \
--hash=sha256:48e2cd6b88c6ed3d5877a3ea40df79d08374088e89bedc32557348848dff250b \
Expand All @@ -24,3 +40,19 @@ protobuf==4.21.9 \
--hash=sha256:e575c57dc8b5b2b2caa436c16d44ef6981f2235eb7179bfc847557886376d740 \
--hash=sha256:f9eae277dd240ae19bb06ff4e2346e771252b0e619421965504bd1b1bba7c5fa
# via -r requirements.in
python-dateutil==2.8.2 \
--hash=sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86 \
--hash=sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9
# via botocore
s3transfer==0.6.0 \
--hash=sha256:06176b74f3a15f61f1b4f25a1fc29a4429040b7647133a463da8fa5bd28d5ecd \
--hash=sha256:2ed07d3866f523cc561bf4a00fc5535827981b117dd7876f036b0c1aca42c947
# via boto3
six==1.16.0 \
--hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \
--hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254
# via python-dateutil
urllib3==1.26.15 \
--hash=sha256:8a388717b9476f934a21484e8c8e61875ab60644d29b9b39e11e4b9dc1c6b305 \
--hash=sha256:aa751d169e23c7479ce47a0cb0da579e3ede798f994f5816a74e4f4500dcea42
# via botocore
4 changes: 3 additions & 1 deletion tink/integration/awskms/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ py_library(
srcs = ["_aws_kms_client.py"],
srcs_version = "PY3",
deps = [
"//tink:tink_python",
"//tink/aead",
"//tink/aead:_kms_aead_key_manager",
"//tink/cc/pybind:tink_bindings",
"//tink/core",
requirement("boto3"),
],
)

Expand All @@ -35,6 +36,7 @@ py_test(
srcs_version = "PY3",
deps = [
":awskms",
":_aws_kms_client",
"//tink:tink_python",
"//tink/testing:helper",
requirement("absl-py"),
Expand Down
2 changes: 1 addition & 1 deletion tink/integration/awskms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""StreamingAead package."""
"""AWS KMS integration package."""

from tink.integration.awskms import _aws_kms_client

Expand Down
127 changes: 112 additions & 15 deletions tink/integration/awskms/_aws_kms_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,95 @@
# limitations under the License.
"""A client for AWS KMS."""

import binascii
import configparser
import re

from typing import Tuple, Any, Dict

import boto3
from botocore import exceptions

import tink
from tink import aead
from tink import core
from tink.aead import _kms_aead_key_manager
from tink.cc.pybind import tink_bindings


AWS_KEYURI_PREFIX = 'aws-kms://'


def _encryption_context(associated_data: bytes) -> Dict[str, str]:
if associated_data:
hex_associated_data = binascii.hexlify(associated_data).decode('utf-8')
return {'associatedData': hex_associated_data}
else:
return dict()


class _AwsKmsAead(aead.Aead):
"""Implements the Aead interface for AWS KMS."""

def __init__(self, client: Any, key_arn: str) -> None:
self.client = client
self.key_arn = key_arn

def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
try:
response = self.client.encrypt(
KeyId=self.key_arn,
Plaintext=plaintext,
EncryptionContext=_encryption_context(associated_data),
)
return response['CiphertextBlob']
except exceptions.ClientError as e:
raise tink.TinkError(e)

def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
try:
response = self.client.decrypt(
KeyId=self.key_arn,
CiphertextBlob=ciphertext,
EncryptionContext=_encryption_context(associated_data),
)
if response['KeyId'] != self.key_arn:
raise tink.TinkError(
'invalid key id: got %s, want %s'
% (self.key_arn, response['KeyId'])
)
return response['Plaintext']
except exceptions.ClientError as e:
raise tink.TinkError(e)


def _key_uri_to_key_arn(key_uri: str) -> str:
if not key_uri.startswith(AWS_KEYURI_PREFIX):
raise tink.TinkError('invalid key URI')
return key_uri[len(AWS_KEYURI_PREFIX) :]


def _parse_config(config_path: str) -> Tuple[str, str]:
"""Returns ('aws_access_key_id', 'aws_secret_access_key') from a config."""
config = configparser.ConfigParser()
config.read(config_path)
if 'default' not in config:
raise ValueError('invalid config: default not found')
default = config['default']
if 'aws_access_key_id' not in default:
raise ValueError('invalid config: aws_access_key_id not found')
aws_access_key_id = default['aws_access_key_id']
if 'aws_secret_access_key' not in default:
raise ValueError('invalid config: aws_secret_access_key not found')
aws_secret_access_key = default['aws_secret_access_key']
return (aws_access_key_id, aws_secret_access_key)


def _get_region_from_key_arn(key_arn: str) -> str:
# An AWS key ARN is of the form
# arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab.
key_arn_parts = key_arn.split(':')
if len(key_arn_parts) < 6:
raise tink.TinkError('invalid key id')
return key_arn_parts[3]


class AwsKmsClient(_kms_aead_key_manager.KmsClient):
Expand All @@ -44,29 +126,32 @@ def __init__(self, key_uri: str, credentials_path: str):
ValueError: If the path or filename of the credentials is invalid.
TinkError: If the key uri is not valid.
"""

match = re.match('aws-kms://arn:aws:kms:([a-z0-9-]+):', key_uri)
if not key_uri:
self._key_uri = ''
elif match:
self._key_uri = key_uri
self._key_arn = None
else:
raise core.TinkError

self.cc_client = tink_bindings.AwsKmsClient(key_uri, credentials_path)
match = re.match('aws-kms://arn:aws:kms:([a-z0-9-]+):', key_uri)
if not match:
raise tink.TinkError('invalid key URI')
self._key_arn = _key_uri_to_key_arn(key_uri)
aws_access_key_id, aws_secret_access_key = _parse_config(credentials_path)
self._aws_access_key_id = aws_access_key_id
self._aws_secret_access_key = aws_secret_access_key

def does_support(self, key_uri: str) -> bool:
"""Returns true iff this client supports KMS key specified in 'key_uri'.
"""Returns true if this client supports KMS key specified in 'key_uri'.
Args:
key_uri: Text, URI of the key to be checked.
Returns: A boolean value which is true if the key is supported and false
otherwise.
"""
return self.cc_client.does_support(key_uri)
if not key_uri.startswith(AWS_KEYURI_PREFIX):
return False
if not self._key_arn:
return True
return _key_uri_to_key_arn(key_uri) == self._key_arn

@core.use_tink_errors
def get_aead(self, key_uri: str) -> aead.Aead:
"""Returns an Aead-primitive backed by KMS key specified by 'key_uri'.
Expand All @@ -79,8 +164,20 @@ def get_aead(self, key_uri: str) -> aead.Aead:
Raises:
TinkError: If the key_uri is not supported.
"""

return aead.AeadCcToPyWrapper(self.cc_client.get_aead(key_uri))
if not self.does_support(key_uri):
if self._key_arn:
raise tink.TinkError(
'This client is bound to %s and cannot use key %s' %
(self._key_arn, key_uri))
raise tink.TinkError(
'This client does not support key %s' % key_uri)
key_arn = _key_uri_to_key_arn(key_uri)
session = boto3.session.Session(
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
region_name=_get_region_from_key_arn(key_arn),
)
return _AwsKmsAead(session.client('kms'), key_arn)

@classmethod
def register_client(cls, key_uri, credentials_path) -> None:
Expand Down
57 changes: 57 additions & 0 deletions tink/integration/awskms/_aws_kms_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@

import os

import tempfile

from absl.testing import absltest

import tink
from tink.integration import awskms
from tink.integration.awskms import _aws_kms_client
from tink.testing import helper


CREDENTIAL_PATH = os.path.join(helper.tink_py_testdata_path(),
'aws/credentials.ini')
KEY_URI = ('aws-kms://arn:aws:kms:us-east-2:235739564943:key/'
Expand Down Expand Up @@ -63,6 +67,59 @@ def test_wrong_credentials_path(self):
with self.assertRaises(ValueError):
awskms.AwsKmsClient(KEY_URI, '../credentials.txt')

def test_parse_valid_credentials_works(self):
config_file = tempfile.NamedTemporaryFile(delete=False)
with open(config_file.name, 'w') as f:
f.write("""
[otherSection]
aws_access_key_id = other_key_id
aws_secret_access_key = other_key
[default]
aws_access_key_id = key_id_123
aws_secret_access_key = key_123""")

aws_access_key_id, aws_secret_access_key = _aws_kms_client._parse_config(
config_file.name
)
self.assertEqual(aws_access_key_id, 'key_id_123')
self.assertEqual(aws_secret_access_key, 'key_123')

os.unlink(config_file.name)

def test_parse_credentials_without_key_id_fails(self):
config_file = tempfile.NamedTemporaryFile(delete=False)
with open(config_file.name, 'w') as f:
f.write("""
[default]
aws_secret_access_key = key_123""")
with self.assertRaises(ValueError):
_aws_kms_client._parse_config(config_file.name)

os.unlink(config_file.name)

def test_parse_credentials_without_key_fails(self):
config_file = tempfile.NamedTemporaryFile(delete=False)
with open(config_file.name, 'w') as f:
f.write("""
[default]
aws_secret_access_key = key_123""")
with self.assertRaises(ValueError):
_aws_kms_client._parse_config(config_file.name)

os.unlink(config_file.name)

def test_parse_credentials_without_default_section_fails(self):
config_file = tempfile.NamedTemporaryFile(delete=False)
with open(config_file.name, 'w') as f:
f.write("""
[otherSection]
aws_access_key_id = other_key_id
aws_secret_access_key = other_key""")
with self.assertRaises(ValueError):
_aws_kms_client._parse_config(config_file.name)

os.unlink(config_file.name)

if __name__ == '__main__':
absltest.main()

0 comments on commit 4732dae

Please sign in to comment.