Skip to content

Commit

Permalink
move get_s3_connection, reduce complexity and increase coverage (ansi…
Browse files Browse the repository at this point in the history
…ble-collections#1139)

move get_s3_connection, reduce complexity and increase coverage

SUMMARY

get_s3_connection is duplicated into modules s3_object and s3_object_info, the goal of this pull request is to move to a single place, reduce code complexity and increase coverage

ISSUE TYPE


Feature Pull Request

COMPONENT NAME

ADDITIONAL INFORMATION

Reviewed-by: Mark Chappell <None>
Reviewed-by: Gonéri Le Bouder <goneri@lebouder.net>
Reviewed-by: Bikouo Aubin <None>
  • Loading branch information
abikouo authored Oct 19, 2022
1 parent 86b20e9 commit e3fbf6b
Show file tree
Hide file tree
Showing 7 changed files with 529 additions and 233 deletions.
3 changes: 3 additions & 0 deletions changelogs/fragments/module_utils_s3-unit-testing.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- module_utils.s3 - Refactor get_s3_connection into a module_utils for S3 modules and expand module_utils.s3 unit tests (https://github.com/ansible-collections/amazon.aws/pull/1139).
164 changes: 116 additions & 48 deletions plugins/module_utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

from ansible.module_utils.basic import to_text
from urllib.parse import urlparse

from .botocore import boto3_conn

try:
from botocore.client import Config
from botocore.exceptions import BotoCoreError, ClientError
except ImportError:
pass # Handled by the calling module
Expand All @@ -22,33 +28,49 @@
import string


def s3_head_objects(client, parts, bucket, obj, versionId):
args = {"Bucket": bucket, "Key": obj}
if versionId:
args["VersionId"] = versionId

for part in range(1, parts + 1):
args["PartNumber"] = part
yield client.head_object(**args)


def calculate_checksum_with_file(client, parts, bucket, obj, versionId, filename):
digests = []
with open(filename, 'rb') as f:
for head in s3_head_objects(client, parts, bucket, obj, versionId):
digests.append(md5(f.read(int(head['ContentLength']))).digest())

digest_squared = b''.join(digests)
return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests))


def calculate_checksum_with_content(client, parts, bucket, obj, versionId, content):
digests = []
offset = 0
for head in s3_head_objects(client, parts, bucket, obj, versionId):
length = int(head['ContentLength'])
digests.append(md5(content[offset:offset + length]).digest())
offset += length

digest_squared = b''.join(digests)
return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests))


def calculate_etag(module, filename, etag, s3, bucket, obj, version=None):
if not HAS_MD5:
return None

if '-' in etag:
# Multi-part ETag; a hash of the hashes of each part.
parts = int(etag[1:-1].split('-')[1])
digests = []

s3_kwargs = dict(
Bucket=bucket,
Key=obj,
)
if version:
s3_kwargs['VersionId'] = version

with open(filename, 'rb') as f:
for part_num in range(1, parts + 1):
s3_kwargs['PartNumber'] = part_num
try:
head = s3.head_object(**s3_kwargs)
except (BotoCoreError, ClientError) as e:
module.fail_json_aws(e, msg="Failed to get head object")
digests.append(md5(f.read(int(head['ContentLength']))))

digest_squared = md5(b''.join(m.digest() for m in digests))
return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests))
try:
return calculate_checksum_with_file(s3, parts, bucket, obj, version, filename)
except (BotoCoreError, ClientError) as e:
module.fail_json_aws(e, msg="Failed to get head object")
else: # Compute the MD5 sum normally
return '"{0}"'.format(module.md5(filename))

Expand All @@ -60,43 +82,89 @@ def calculate_etag_content(module, content, etag, s3, bucket, obj, version=None)
if '-' in etag:
# Multi-part ETag; a hash of the hashes of each part.
parts = int(etag[1:-1].split('-')[1])
digests = []
offset = 0

s3_kwargs = dict(
Bucket=bucket,
Key=obj,
)
if version:
s3_kwargs['VersionId'] = version

for part_num in range(1, parts + 1):
s3_kwargs['PartNumber'] = part_num
try:
head = s3.head_object(**s3_kwargs)
except (BotoCoreError, ClientError) as e:
module.fail_json_aws(e, msg="Failed to get head object")
length = int(head['ContentLength'])
digests.append(md5(content[offset:offset + length]))
offset += length

digest_squared = md5(b''.join(m.digest() for m in digests))
return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests))
try:
return calculate_checksum_with_content(s3, parts, bucket, obj, version, content)
except (BotoCoreError, ClientError) as e:
module.fail_json_aws(e, msg="Failed to get head object")
else: # Compute the MD5 sum normally
return '"{0}"'.format(md5(content).hexdigest())


def validate_bucket_name(module, name):
def validate_bucket_name(name):
# See: https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html
if len(name) < 3:
module.fail_json(msg='the length of an S3 bucket must be at least 3 characters')
return 'the length of an S3 bucket must be at least 3 characters'
if len(name) > 63:
module.fail_json(msg='the length of an S3 bucket cannot exceed 63 characters')
return 'the length of an S3 bucket cannot exceed 63 characters'

legal_characters = string.ascii_lowercase + ".-" + string.digits
illegal_characters = [c for c in name if c not in legal_characters]
if illegal_characters:
module.fail_json(msg='invalid character(s) found in the bucket name')
return 'invalid character(s) found in the bucket name'
if name[-1] not in string.ascii_lowercase + string.digits:
module.fail_json(msg='bucket names must begin and end with a letter or number')
return True
return 'bucket names must begin and end with a letter or number'
return None


# Spot special case of fakes3.
def is_fakes3(url):
""" Return True if endpoint_url has scheme fakes3:// """
result = False
if url is not None:
result = urlparse(url).scheme in ('fakes3', 'fakes3s')
return result


def parse_fakes3_endpoint(url):
fakes3 = urlparse(url)
protocol = "http"
port = fakes3.port or 80
if fakes3.scheme == 'fakes3s':
protocol = "https"
port = fakes3.port or 443
endpoint_url = f"{protocol}://{fakes3.hostname}:{to_text(port)}"
use_ssl = bool(fakes3.scheme == 'fakes3s')
return {"endpoint": endpoint_url, "use_ssl": use_ssl}


def parse_ceph_endpoint(url):
ceph = urlparse(url)
use_ssl = bool(ceph.scheme == 'https')
return {"endpoint": url, "use_ssl": use_ssl}


def parse_default_endpoint(url, mode, encryption_mode, dualstack, sig_4):
result = {"endpoint": url}
config = {}
if (mode in ('get', 'getstr') and sig_4) or (mode == "put" and encryption_mode == "aws:kms"):
config["signature_version"] = "s3v4"
if dualstack:
config["s3"] = {"use_dualstack_endpoint": True}
if config != {}:
result["config"] = Config(**config)
return result


def s3_conn_params(mode, encryption_mode, dualstack, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
params = {"conn_type": "client", "resource": "s3", "region": location, **aws_connect_kwargs}
if ceph:
endpoint_p = parse_ceph_endpoint(endpoint_url)
elif is_fakes3(endpoint_url):
endpoint_p = parse_fakes3_endpoint(endpoint_url)
else:
endpoint_p = parse_default_endpoint(endpoint_url, mode, encryption_mode, dualstack, sig_4)

params.update(endpoint_p)
return params


def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
s3_conn = s3_conn_params(module.params.get("mode"),
module.params.get("encryption_mode"),
module.params.get("dualstack"),
aws_connect_kwargs,
location,
ceph,
endpoint_url,
sig_4)
return boto3_conn(module, **s3_conn)
4 changes: 3 additions & 1 deletion plugins/modules/s3_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,9 @@ def main():
region, _ec2_url, aws_connect_kwargs = get_aws_connection_info(module, boto3=True)

if module.params.get('validate_bucket_name'):
validate_bucket_name(module, module.params["name"])
err = validate_bucket_name(module.params["name"])
if err:
module.fail_json(msg=err)

if region in ('us-east-1', '', None):
# default to US Standard region
Expand Down
50 changes: 4 additions & 46 deletions plugins/modules/s3_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,15 +405,13 @@
except ImportError:
pass # Handled by AnsibleAWSModule

from ansible.module_utils.basic import to_text
from ansible.module_utils.basic import to_native
from ansible.module_utils.six.moves.urllib.parse import urlparse

from ansible_collections.amazon.aws.plugins.module_utils.core import AnsibleAWSModule
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_message
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_conn
from ansible_collections.amazon.aws.plugins.module_utils.s3 import get_s3_connection
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import ansible_dict_to_boto3_tag_list
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_tag_list_to_ansible_dict
Expand Down Expand Up @@ -845,48 +843,6 @@ def copy_object_to_bucket(module, s3, bucket, obj, encrypt, metadata, validate,
module.fail_json_aws(e, msg="Failed while copying object %s from bucket %s." % (obj, module.params['copy_src'].get('Bucket')))


def is_fakes3(endpoint_url):
""" Return True if endpoint_url has scheme fakes3:// """
if endpoint_url is not None:
return urlparse(endpoint_url).scheme in ('fakes3', 'fakes3s')
else:
return False


def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
if ceph: # TODO - test this
ceph = urlparse(endpoint_url)
params = dict(module=module, conn_type='client', resource='s3', use_ssl=ceph.scheme == 'https',
region=location, endpoint=endpoint_url, **aws_connect_kwargs)
elif is_fakes3(endpoint_url):
fakes3 = urlparse(endpoint_url)
port = fakes3.port
if fakes3.scheme == 'fakes3s':
protocol = "https"
if port is None:
port = 443
else:
protocol = "http"
if port is None:
port = 80
params = dict(module=module, conn_type='client', resource='s3', region=location,
endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)),
use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs)
else:
params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=endpoint_url, **aws_connect_kwargs)
if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms':
params['config'] = botocore.client.Config(signature_version='s3v4')
elif module.params['mode'] in ('get', 'getstr', 'geturl') and sig_4:
params['config'] = botocore.client.Config(signature_version='s3v4')
if module.params['dualstack']:
dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True})
if 'config' in params:
params['config'] = params['config'].merge(dualconf)
else:
params['config'] = dualconf
return boto3_conn(**params)


def get_current_object_tags_dict(s3, bucket, obj, version=None):
try:
if version:
Expand Down Expand Up @@ -1040,7 +996,9 @@ def main():
bucket_canned_acl = ["private", "public-read", "public-read-write", "authenticated-read"]

if module.params.get('validate_bucket_name'):
validate_bucket_name(module, bucket)
err = validate_bucket_name(bucket)
if err:
module.fail_json(msg=err)

if overwrite not in ['always', 'never', 'different', 'latest']:
if module.boolean(overwrite):
Expand Down
50 changes: 2 additions & 48 deletions plugins/modules/s3_object_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,16 +440,13 @@
except ImportError:
pass # Handled by AnsibleAWSModule

from ansible.module_utils.basic import to_text
from ansible.module_utils.six.moves.urllib.parse import urlparse

from ansible_collections.amazon.aws.plugins.module_utils.core import AnsibleAWSModule
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import camel_dict_to_snake_dict
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_tag_list_to_ansible_dict
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_conn
from ansible_collections.amazon.aws.plugins.module_utils.s3 import get_s3_connection


def describe_s3_object_acl(connection, bucket_name, object_name):
Expand Down Expand Up @@ -670,49 +667,6 @@ def object_check(connection, module, bucket_name, object_name):
module.fail_json_aws(e, msg="The object %s does not exist or is missing access permissions." % object_name)


# To get S3 connection, in case of dealing with ceph, dualstack, etc.
def is_fakes3(endpoint_url):
""" Return True if endpoint_url has scheme fakes3:// """
if endpoint_url is not None:
return urlparse(endpoint_url).scheme in ('fakes3', 'fakes3s')
else:
return False


def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
if ceph: # TODO - test this
ceph = urlparse(endpoint_url)
params = dict(module=module, conn_type='client', resource='s3', use_ssl=ceph.scheme == 'https',
region=location, endpoint=endpoint_url, **aws_connect_kwargs)
elif is_fakes3(endpoint_url):
fakes3 = urlparse(endpoint_url)
port = fakes3.port
if fakes3.scheme == 'fakes3s':
protocol = "https"
if port is None:
port = 443
else:
protocol = "http"
if port is None:
port = 80
params = dict(module=module, conn_type='client', resource='s3', region=location,
endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)),
use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs)
else:
params = dict(module=module, conn_type='client', resource='s3', region=location, endpoint=endpoint_url, **aws_connect_kwargs)
if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms':
params['config'] = botocore.client.Config(signature_version='s3v4')
elif module.params['mode'] in ('get', 'getstr') and sig_4:
params['config'] = botocore.client.Config(signature_version='s3v4')
if module.params['dualstack']:
dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True})
if 'config' in params:
params['config'] = params['config'].merge(dualconf)
else:
params['config'] = dualconf
return boto3_conn(**params)


def main():

argument_spec = dict(
Expand All @@ -730,7 +684,7 @@ def main():
),
bucket_name=dict(required=True, type='str'),
object_name=dict(type='str'),
dualstack=dict(default='no', type='bool'),
dualstack=dict(default=False, type='bool'),
ceph=dict(default=False, type='bool', aliases=['rgw']),
)

Expand Down
Loading

0 comments on commit e3fbf6b

Please sign in to comment.