Skip to content

Commit

Permalink
Add InfoSetter class.
Browse files Browse the repository at this point in the history
This refactoring removes the necessity of passing
arguments through the ``FileGenerator`` class
in order for the ``FileInfo`` class to obtain
the arguments it requires to perform an operation.
  • Loading branch information
kyleknap committed Aug 12, 2014
1 parent dc5c6f2 commit c931e54
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 228 deletions.
30 changes: 19 additions & 11 deletions awscli/customizations/s3/filegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from dateutil.parser import parse
from dateutil.tz import tzlocal

from awscli.customizations.s3.fileinfo import FileInfo
from awscli.customizations.s3.utils import find_bucket_key, get_file_stat
from awscli.customizations.s3.utils import BucketLister
from awscli.errorhandler import ClientError
Expand Down Expand Up @@ -46,6 +45,20 @@ def __init__(self, directory, filename):
super(FileDecodingError, self).__init__(self.error_message)


class FileBase(object):
def __init__(self, src, dest=None, compare_key=None, size=None,
last_update=None, src_type=None, dest_type=None,
operation_name=None):
self.src = src
self.dest = dest
self.compare_key = compare_key
self.size = size
self.last_update = last_update
self.src_type = src_type
self.dest_type = dest_type
self.operation_name = operation_name


class FileGenerator(object):
"""
This is a class the creates a generator to yield files based on information
Expand All @@ -55,12 +68,9 @@ class FileGenerator(object):
``FileInfo`` objects to send to a ``Comparator`` or ``S3Handler``.
"""
def __init__(self, service, endpoint, operation_name,
follow_symlinks=True, source_endpoint=None):
follow_symlinks=True):
self._service = service
self._endpoint = endpoint
self._source_endpoint = endpoint
if source_endpoint:
self._source_endpoint = source_endpoint
self.operation_name = operation_name
self.follow_symlinks = follow_symlinks

Expand Down Expand Up @@ -90,13 +100,11 @@ def call(self, files):
sep_table[dest_type])
else:
dest_path = dest['path']
yield FileInfo(src=src_path, dest=dest_path,
yield FileBase(src=src_path, dest=dest_path,
compare_key=compare_key, size=size,
last_update=last_update, src_type=src_type,
service=self._service, endpoint=self._endpoint,
dest_type=dest_type,
operation_name=self.operation_name,
source_endpoint=self._source_endpoint)
operation_name=self.operation_name)

def list_files(self, path, dir_op):
"""
Expand Down Expand Up @@ -195,7 +203,7 @@ def list_objects(self, s3_path, dir_op):
yield self._list_single_object(s3_path)
else:
operation = self._service.get_operation('ListObjects')
lister = BucketLister(operation, self._source_endpoint)
lister = BucketLister(operation, self._endpoint)
for key in lister.list_objects(bucket=bucket, prefix=prefix):
source_path, size, last_update = key
if size == 0 and source_path.endswith('/'):
Expand All @@ -221,7 +229,7 @@ def _list_single_object(self, s3_path):
operation = self._service.get_operation('HeadObject')
try:
response = operation.call(
self._source_endpoint, bucket=bucket, key=key)[1]
self._endpoint, bucket=bucket, key=key)[1]
except ClientError as e:
# We want to try to give a more helpful error message.
# This is what the customer is going to see so we want to
Expand Down
49 changes: 49 additions & 0 deletions awscli/customizations/s3/infosetter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2014 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from awscli.customizations.s3.fileinfo import FileInfo


class InfoSetter(object):
"""
This class takes a ``FileBase`` object's attributes and generates
a ``FileInfo`` object so that the operation can be performed.
"""
def __init__(self, service, endpoint, source_endpoint=None,
parameters = None):
self._service = service
self._endpoint = endpoint
self._source_endpoint = endpoint
if source_endpoint:
self._source_endpoint = source_endpoint
self._parameters = parameters

def call(self, files):
for file_base in files:
file_info = self.inject_info(file_base)
yield file_info

def inject_info(self, file_base):
file_info_attr = {}
file_info_attr['src'] = file_base.src
file_info_attr['dest'] = file_base.dest
file_info_attr['compare_key'] = file_base.compare_key
file_info_attr['size'] = file_base.size
file_info_attr['last_update'] = file_base.last_update
file_info_attr['src_type'] = file_base.src_type
file_info_attr['dest_type'] = file_base.dest_type
file_info_attr['operation_name'] = file_base.operation_name
file_info_attr['service'] = self._service
file_info_attr['endpoint'] = self._endpoint
file_info_attr['source_endpoint'] = self._source_endpoint
file_info_attr['parameters'] = self._parameters
return FileInfo(**file_info_attr)
16 changes: 13 additions & 3 deletions awscli/customizations/s3/subcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from awscli.customizations.commands import BasicCommand
from awscli.customizations.s3.comparator import Comparator
from awscli.customizations.s3.infosetter import InfoSetter
from awscli.customizations.s3.fileformat import FileFormat
from awscli.customizations.s3.filegenerator import FileGenerator
from awscli.customizations.s3.fileinfo import TaskInfo
Expand Down Expand Up @@ -486,6 +487,7 @@ def set_endpoints(self):
endpoint_url=self.parameters['endpoint_url'],
verify=self.parameters['verify_ssl']
)
self._source_endpoint = self._endpoint
if self.parameters['source_region']:
if self.parameters['paths_type'] == 's3s3':
self._source_endpoint = get_endpoint(
Expand All @@ -509,6 +511,8 @@ def create_instructions(self):
self.instructions.append('filters')
if self.cmd == 'sync':
self.instructions.append('comparator')
if self.cmd not in ['mb', 'rb']:
self.instructions.append('info_setter')
self.instructions.append('s3_handler')

def run(self):
Expand Down Expand Up @@ -551,17 +555,19 @@ def run(self):
'rb': 'remove_bucket'
}
operation_name = cmd_translation[paths_type][self.cmd]
file_generator = FileGenerator(self._service, self._endpoint,
file_generator = FileGenerator(self._service,
self._source_endpoint,
operation_name,
self.parameters['follow_symlinks'],
self._source_endpoint)
self.parameters['follow_symlinks'])
rev_generator = FileGenerator(self._service, self._endpoint, '',
self.parameters['follow_symlinks'])
taskinfo = [TaskInfo(src=files['src']['path'],
src_type='s3',
operation_name=operation_name,
service=self._service,
endpoint=self._endpoint)]
info_setter = InfoSetter(self._service, self._endpoint,
self._source_endpoint, self.parameters)
s3handler = S3Handler(self.session, self.parameters)

command_dict = {}
Expand All @@ -572,21 +578,25 @@ def run(self):
'filters': [create_filter(self.parameters),
create_filter(self.parameters)],
'comparator': [Comparator(self.parameters)],
'info_setter': [info_setter],
's3_handler': [s3handler]}
elif self.cmd == 'cp':
command_dict = {'setup': [files],
'file_generator': [file_generator],
'filters': [create_filter(self.parameters)],
'info_setter': [info_setter],
's3_handler': [s3handler]}
elif self.cmd == 'rm':
command_dict = {'setup': [files],
'file_generator': [file_generator],
'filters': [create_filter(self.parameters)],
'info_setter': [info_setter],
's3_handler': [s3handler]}
elif self.cmd == 'mv':
command_dict = {'setup': [files],
'file_generator': [file_generator],
'filters': [create_filter(self.parameters)],
'info_setter': [info_setter],
's3_handler': [s3handler]}
elif self.cmd == 'mb':
command_dict = {'setup': [taskinfo],
Expand Down
47 changes: 16 additions & 31 deletions tests/integration/customizations/s3/test_filegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

import botocore.session
from awscli import EnvironmentVariables
from awscli.customizations.s3.filegenerator import FileGenerator
from awscli.customizations.s3.fileinfo import FileInfo
from awscli.customizations.s3.filegenerator import FileGenerator, FileBase
from tests.unit.customizations.s3 import make_s3_files, s3_cleanup, \
compare_files

Expand Down Expand Up @@ -52,16 +51,14 @@ def test_s3_file(self):
result_list = list(
FileGenerator(self.service, self.endpoint, '').call(
input_s3_file))
file_info = FileInfo(src=self.file1, dest='text1.txt',
file_base = FileBase(src=self.file1, dest='text1.txt',
compare_key='text1.txt',
size=expected_file_size,
last_update=result_list[0].last_update,
src_type='s3',
dest_type='local', operation_name='',
endpoint=self.endpoint,
source_endpoint=self.endpoint)
dest_type='local', operation_name='')

expected_list = [file_info]
expected_list = [file_base]
self.assertEqual(len(result_list), 1)
compare_files(self, result_list[0], expected_list[0])

Expand All @@ -77,26 +74,22 @@ def test_s3_directory(self):
result_list = list(
FileGenerator(self.service, self.endpoint, '').call(
input_s3_file))
file_info = FileInfo(src=self.file2,
file_base = FileBase(src=self.file2,
dest='another_directory' + os.sep + 'text2.txt',
compare_key='another_directory/text2.txt',
size=21,
last_update=result_list[0].last_update,
src_type='s3',
dest_type='local', operation_name='',
endpoint=self.endpoint,
source_endpoint=self.endpoint)
file_info2 = FileInfo(src=self.file1,
dest_type='local', operation_name='')
file_base2 = FileBase(src=self.file1,
dest='text1.txt',
compare_key='text1.txt',
size=15,
last_update=result_list[1].last_update,
src_type='s3',
dest_type='local', operation_name='',
endpoint=self.endpoint,
source_endpoint=self.endpoint)
dest_type='local', operation_name='')

expected_result = [file_info, file_info2]
expected_result = [file_base, file_base2]
self.assertEqual(len(result_list), 2)
compare_files(self, result_list[0], expected_result[0])
compare_files(self, result_list[1], expected_result[1])
Expand All @@ -115,40 +108,32 @@ def test_s3_delete_directory(self):
'delete').call(
input_s3_file))

file_info1 = FileInfo(
file_base1 = FileBase(
src=self.bucket + '/another_directory/',
dest='another_directory' + os.sep,
compare_key='another_directory/',
size=0,
last_update=result_list[0].last_update,
src_type='s3',
dest_type='local', operation_name='delete',
service=self.service, endpoint=self.endpoint,
source_endpoint=self.endpoint)
file_info2 = FileInfo(
dest_type='local', operation_name='delete')
file_base2 = FileBase(
src=self.file2,
dest='another_directory' + os.sep + 'text2.txt',
compare_key='another_directory/text2.txt',
size=21,
last_update=result_list[1].last_update,
src_type='s3',
dest_type='local', operation_name='delete',
service=self.service,
endpoint=self.endpoint,
source_endpoint=self.endpoint)
file_info3 = FileInfo(
dest_type='local', operation_name='delete')
file_base3 = FileBase(
src=self.file1,
dest='text1.txt',
compare_key='text1.txt',
size=15,
last_update=result_list[2].last_update,
src_type='s3',
dest_type='local', operation_name='delete',
service=self.service,
endpoint=self.endpoint,
source_endpoint=self.endpoint)
dest_type='local', operation_name='delete')

expected_list = [file_info1, file_info2, file_info3]
expected_list = [file_base1, file_base2, file_base3]
self.assertEqual(len(result_list), 3)
compare_files(self, result_list[0], expected_list[0])
compare_files(self, result_list[1], expected_list[1])
Expand Down
12 changes: 1 addition & 11 deletions tests/unit/customizations/s3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def s3_cleanup(bucket, session, key1='text1.txt', key2='text2.txt'):

def compare_files(self, result_file, ref_file):
"""
Ensures that the FileInfo's properties are what they
Ensures that the FileBase's properties are what they
are suppose to be.
"""
self.assertEqual(result_file.src, ref_file.src)
Expand All @@ -161,16 +161,6 @@ def compare_files(self, result_file, ref_file):
self.assertEqual(result_file.src_type, ref_file.src_type)
self.assertEqual(result_file.dest_type, ref_file.dest_type)
self.assertEqual(result_file.operation_name, ref_file.operation_name)
compare_endpoints(self, result_file.endpoint, ref_file.endpoint)
compare_endpoints(self, result_file.source_endpoint,
ref_file.source_endpoint)


def compare_endpoints(self, endpoint, ref_endpoint):
self.assertEqual(endpoint.region_name, ref_endpoint.region_name)
if getattr(endpoint, 'endpoint_url', None):
self.assertEqual(endpoint.endpoint_url, ref_endpoint.endpoint_url)
self.assertEqual(endpoint.verify, ref_endpoint.verify)


def list_contents(bucket, session):
Expand Down
Loading

0 comments on commit c931e54

Please sign in to comment.