Skip to content

Commit

Permalink
Added a --source-region parameter.
Browse files Browse the repository at this point in the history
This parameter ensures the ability to do
trans-region syncs, moves, and copies.  Tests
were expanded as well to better test
handling endpoints.
  • Loading branch information
kyleknap committed Aug 11, 2014
1 parent 57424a5 commit dc5c6f2
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 52 deletions.
13 changes: 9 additions & 4 deletions awscli/customizations/s3/filegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ class FileGenerator(object):
under the same common prefix. The generator yields corresponding
``FileInfo`` objects to send to a ``Comparator`` or ``S3Handler``.
"""
def __init__(self, service, endpoint, operation_name, follow_symlinks=True):
def __init__(self, service, endpoint, operation_name,
follow_symlinks=True, source_endpoint=None):
self._service = service
self._endpoint = endpoint
self._source_endpoint = endpoint
if source_endpoint:
self._source_endpoint = source_endpoint
self.operation_name = operation_name
self.follow_symlinks = follow_symlinks

Expand Down Expand Up @@ -91,7 +95,8 @@ def call(self, files):
last_update=last_update, src_type=src_type,
service=self._service, endpoint=self._endpoint,
dest_type=dest_type,
operation_name=self.operation_name)
operation_name=self.operation_name,
source_endpoint=self._source_endpoint)

def list_files(self, path, dir_op):
"""
Expand Down Expand Up @@ -190,7 +195,7 @@ def list_objects(self, s3_path, dir_op):
yield self._list_single_object(s3_path)
else:
operation = self._service.get_operation('ListObjects')
lister = BucketLister(operation, self._endpoint)
lister = BucketLister(operation, self._source_endpoint)
for key in lister.list_objects(bucket=bucket, prefix=prefix):
source_path, size, last_update = key
if size == 0 and source_path.endswith('/'):
Expand All @@ -216,7 +221,7 @@ def _list_single_object(self, s3_path):
operation = self._service.get_operation('HeadObject')
try:
response = operation.call(
self._endpoint, bucket=bucket, key=key)[1]
self._source_endpoint, bucket=bucket, key=key)[1]
except ClientError as e:
# We want to try to give a more helpful error message.
# This is what the customer is going to see so we want to
Expand Down
6 changes: 4 additions & 2 deletions awscli/customizations/s3/fileinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class FileInfo(TaskInfo):
def __init__(self, src, dest=None, compare_key=None, size=None,
last_update=None, src_type=None, dest_type=None,
operation_name=None, service=None, endpoint=None,
parameters=None):
parameters=None, source_endpoint=None):
super(FileInfo, self).__init__(src, src_type=src_type,
operation_name=operation_name,
service=service,
Expand All @@ -156,6 +156,7 @@ def __init__(self, src, dest=None, compare_key=None, size=None,
else:
self.parameters = {'acl': None,
'sse': None}
self.source_endpoint = source_endpoint

def _permission_to_param(self, permission):
if permission == 'read':
Expand Down Expand Up @@ -256,7 +257,8 @@ def delete(self):
"""
if (self.src_type == 's3'):
bucket, key = find_bucket_key(self.src)
params = {'endpoint': self.endpoint, 'bucket': bucket, 'key': key}
params = {'endpoint': self.source_endpoint, 'bucket': bucket,
'key': key}
response_data, http = operate(self.service, 'DeleteObject',
params)
else:
Expand Down
52 changes: 41 additions & 11 deletions awscli/customizations/s3/subcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@
CONTENT_LANGUAGE = {'name': 'content-language', 'nargs': 1,
'help_text': ("The language the content is in.")}

SOURCE_REGION = {'name': 'source-region', 'nargs': 1,
'help_text': (
"When transferring objects from an s3 bucket to an s3 "
"bucket, this specifies the region of the source bucket."
" Note the region specified by ``--region`` or through "
"configuration of the CLI refers to the region of the "
"destination bucket. If ``--source-region`` is not "
"specified the region of the source will be the same "
"as the region of the destination bucket.")}

EXPIRES = {'name': 'expires', 'nargs': 1, 'help_text': ("The date and time at "
"which the object is no longer cacheable.")}

Expand Down Expand Up @@ -198,20 +208,22 @@
FOLLOW_SYMLINKS, NO_FOLLOW_SYMLINKS, NO_GUESS_MIME_TYPE,
SSE, STORAGE_CLASS, GRANTS, WEBSITE_REDIRECT, CONTENT_TYPE,
CACHE_CONTROL, CONTENT_DISPOSITION, CONTENT_ENCODING,
CONTENT_LANGUAGE, EXPIRES]
CONTENT_LANGUAGE, EXPIRES, SOURCE_REGION]

SYNC_ARGS = [DELETE, EXACT_TIMESTAMPS, SIZE_ONLY] + TRANSFER_ARGS


def get_endpoint(service, region, endpoint_url, verify):
return service.get_endpoint(region_name=region, endpoint_url=endpoint_url,
verify=verify)


class S3Command(BasicCommand):
def _run_main(self, parsed_args, parsed_globals):
self.service = self._session.get_service('s3')
self.endpoint = self._get_endpoint(self.service, parsed_globals)

def _get_endpoint(self, service, parsed_globals):
return service.get_endpoint(region_name=parsed_globals.region,
endpoint_url=parsed_globals.endpoint_url,
verify=parsed_globals.verify_ssl)
self.endpoint = get_endpoint(self.service, parsed_globals.region,
parsed_globals.endpoint_url,
parsed_globals.verify_ssl)


class ListCommand(S3Command):
Expand Down Expand Up @@ -363,6 +375,7 @@ def _run_main(self, parsed_args, parsed_globals):
cmd_params.check_force(parsed_globals)
cmd = CommandArchitecture(self._session, self.NAME,
cmd_params.parameters)
cmd.set_endpoints()
cmd.create_instructions()
return cmd.run()

Expand Down Expand Up @@ -463,10 +476,24 @@ def __init__(self, session, cmd, parameters):
self.parameters = parameters
self.instructions = []
self._service = self.session.get_service('s3')
self._endpoint = self._service.get_endpoint(
region_name=self.parameters['region'],
self._endpoint = None
self._source_endpoint = None

def set_endpoints(self):
self._endpoint = get_endpoint(
self._service,
region=self.parameters['region'],
endpoint_url=self.parameters['endpoint_url'],
verify=self.parameters['verify_ssl'])
verify=self.parameters['verify_ssl']
)
if self.parameters['source_region']:
if self.parameters['paths_type'] == 's3s3':
self._source_endpoint = get_endpoint(
self._service,
region=self.parameters['source_region'][0],
endpoint_url=None,
verify=self.parameters['verify_ssl']
)

def create_instructions(self):
"""
Expand Down Expand Up @@ -526,7 +553,8 @@ def run(self):
operation_name = cmd_translation[paths_type][self.cmd]
file_generator = FileGenerator(self._service, self._endpoint,
operation_name,
self.parameters['follow_symlinks'])
self.parameters['follow_symlinks'],
self._source_endpoint)
rev_generator = FileGenerator(self._service, self._endpoint, '',
self.parameters['follow_symlinks'])
taskinfo = [TaskInfo(src=files['src']['path'],
Expand Down Expand Up @@ -610,6 +638,8 @@ def __init__(self, session, cmd, parameters, usage):
self.parameters['dir_op'] = False
if 'follow_symlinks' not in parameters:
self.parameters['follow_symlinks'] = True
if 'source_region' not in parameters:
self.parameters['source_region'] = None
if self.cmd in ['sync', 'mb', 'rb']:
self.parameters['dir_op'] = True

Expand Down
3 changes: 3 additions & 0 deletions awscli/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ def create_file(self, filename, contents, mtime=None):
os.makedirs(os.path.dirname(full_path))
with open(full_path, 'w') as f:
f.write(contents)
current_time = os.path.getmtime(full_path)
# Subtract a few years off the last modification date.
os.utime(full_path, (current_time, current_time - 100000000))
if mtime is not None:
os.utime(full_path, (mtime, mtime))
return full_path
Expand Down
21 changes: 15 additions & 6 deletions tests/integration/customizations/s3/test_filegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def test_s3_file(self):
size=expected_file_size,
last_update=result_list[0].last_update,
src_type='s3',
dest_type='local', operation_name='')
dest_type='local', operation_name='',
endpoint=self.endpoint,
source_endpoint=self.endpoint)

expected_list = [file_info]
self.assertEqual(len(result_list), 1)
Expand All @@ -81,14 +83,18 @@ def test_s3_directory(self):
size=21,
last_update=result_list[0].last_update,
src_type='s3',
dest_type='local', operation_name='')
dest_type='local', operation_name='',
endpoint=self.endpoint,
source_endpoint=self.endpoint)
file_info2 = FileInfo(src=self.file1,
dest='text1.txt',
compare_key='text1.txt',
size=15,
last_update=result_list[1].last_update,
src_type='s3',
dest_type='local', operation_name='')
dest_type='local', operation_name='',
endpoint=self.endpoint,
source_endpoint=self.endpoint)

expected_result = [file_info, file_info2]
self.assertEqual(len(result_list), 2)
Expand Down Expand Up @@ -117,7 +123,8 @@ def test_s3_delete_directory(self):
last_update=result_list[0].last_update,
src_type='s3',
dest_type='local', operation_name='delete',
service=self.service, endpoint=self.endpoint)
service=self.service, endpoint=self.endpoint,
source_endpoint=self.endpoint)
file_info2 = FileInfo(
src=self.file2,
dest='another_directory' + os.sep + 'text2.txt',
Expand All @@ -127,7 +134,8 @@ def test_s3_delete_directory(self):
src_type='s3',
dest_type='local', operation_name='delete',
service=self.service,
endpoint=self.endpoint)
endpoint=self.endpoint,
source_endpoint=self.endpoint)
file_info3 = FileInfo(
src=self.file1,
dest='text1.txt',
Expand All @@ -137,7 +145,8 @@ def test_s3_delete_directory(self):
src_type='s3',
dest_type='local', operation_name='delete',
service=self.service,
endpoint=self.endpoint)
endpoint=self.endpoint,
source_endpoint=self.endpoint)

expected_list = [file_info1, file_info2, file_info3]
self.assertEqual(len(result_list), 3)
Expand Down
Loading

0 comments on commit dc5c6f2

Please sign in to comment.