From 74b6be138ac675ec01eae48a9dc2cb7b9d453dda Mon Sep 17 00:00:00 2001 From: kyleknap Date: Mon, 25 Aug 2014 10:38:02 -0700 Subject: [PATCH] Added the ability to stream data using ``cp``. This feature enables users to stream from stdin to s3 or from s3 to stdout. Streaming large files is both multithreaded and uses multipart transfers. The streaming feature is limited to single file ``cp`` commands. --- CHANGELOG.rst | 3 + awscli/customizations/s3/constants.py | 1 + awscli/customizations/s3/executor.py | 36 ++-- awscli/customizations/s3/filegenerator.py | 23 ++- awscli/customizations/s3/fileinfo.py | 110 ++++++++---- awscli/customizations/s3/fileinfobuilder.py | 1 + awscli/customizations/s3/s3handler.py | 156 +++++++++++++--- awscli/customizations/s3/subcommands.py | 25 ++- awscli/customizations/s3/tasks.py | 73 ++++++-- awscli/customizations/s3/utils.py | 20 ++- awscli/examples/s3/cp.rst | 12 ++ awscli/testutils.py | 15 +- .../customizations/s3/test_plugin.py | 101 ++++++++++- tests/unit/customizations/s3/__init__.py | 25 ++- tests/unit/customizations/s3/test_executor.py | 38 ++-- .../customizations/s3/test_filegenerator.py | 18 ++ tests/unit/customizations/s3/test_fileinfo.py | 14 ++ .../customizations/s3/test_fileinfobuilder.py | 3 +- .../unit/customizations/s3/test_s3handler.py | 167 +++++++++++++++++- .../customizations/s3/test_subcommands.py | 55 ++++-- tests/unit/customizations/s3/test_tasks.py | 113 +++++++++++- tests/unit/test_completer.py | 5 +- 22 files changed, 871 insertions(+), 143 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a1de4f7e2a2fb..c0b5afe6645f7 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,9 @@ CHANGELOG Next Release (TBD) ================== +* feature:``aws s3 cp``: Added ability to upload local + file streams from standard input to s3 and download s3 + objects as local file streams to standard output. * feature:Page Size: Add a ``--page-size`` option, that controls page size when perfoming an operation that uses pagination. 
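For reference, the feature is limited to single-file ``cp`` commands where ``-`` stands in for the local path. The invocations below mirror the examples added to ``awscli/examples/s3/cp.rst`` further down; the bucket name and the piped ``tar`` command are illustrative placeholders, not part of the patch::

    # Upload from standard input to an S3 object.
    aws s3 cp - s3://mybucket/stream.txt

    # Download an S3 object to standard output.
    aws s3 cp s3://mybucket/stream.txt -

    # Either end of a shell pipeline works, e.g. archiving on the fly.
    tar czf - some_directory | aws s3 cp - s3://mybucket/archive.tar.gz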
diff --git a/awscli/customizations/s3/constants.py b/awscli/customizations/s3/constants.py index d0877eed26b24..7c0b7c4fcbc3d 100644 --- a/awscli/customizations/s3/constants.py +++ b/awscli/customizations/s3/constants.py @@ -18,3 +18,4 @@ MAX_SINGLE_UPLOAD_SIZE = 5 * (1024 ** 3) MAX_UPLOAD_SIZE = 5 * (1024 ** 4) MAX_QUEUE_SIZE = 1000 +STREAM_INPUT_TIMEOUT = 0.1 diff --git a/awscli/customizations/s3/executor.py b/awscli/customizations/s3/executor.py index 872f181ef055b..c6866246022a4 100644 --- a/awscli/customizations/s3/executor.py +++ b/awscli/customizations/s3/executor.py @@ -15,8 +15,8 @@ import sys import threading -from awscli.customizations.s3.utils import uni_print, \ - IORequest, IOCloseRequest, StablePriorityQueue +from awscli.customizations.s3.utils import uni_print, bytes_print, \ + IORequest, IOCloseRequest, StablePriorityQueue from awscli.customizations.s3.tasks import OrderableTask @@ -50,8 +50,7 @@ def __init__(self, num_threads, result_queue, self.quiet = quiet self.threads_list = [] self.write_queue = write_queue - self.print_thread = PrintThread(self.result_queue, - self.quiet) + self.print_thread = PrintThread(self.result_queue, self.quiet) self.print_thread.daemon = True self.io_thread = IOWriterThread(self.write_queue) @@ -153,23 +152,28 @@ def run(self): self._cleanup() return elif isinstance(task, IORequest): - filename, offset, data = task - fileobj = self.fd_descriptor_cache.get(filename) - if fileobj is None: - fileobj = open(filename, 'rb+') - self.fd_descriptor_cache[filename] = fileobj - fileobj.seek(offset) + filename, offset, data, is_stream = task + if is_stream: + fileobj = sys.stdout + bytes_print(data) + else: + fileobj = self.fd_descriptor_cache.get(filename) + if fileobj is None: + fileobj = open(filename, 'rb+') + self.fd_descriptor_cache[filename] = fileobj + fileobj.seek(offset) + fileobj.write(data) LOGGER.debug("Writing data to: %s, offset: %s", filename, offset) - fileobj.write(data) fileobj.flush() elif isinstance(task, IOCloseRequest): LOGGER.debug("IOCloseRequest received for %s, closing file.", task.filename) - fileobj = self.fd_descriptor_cache.get(task.filename) - if fileobj is not None: - fileobj.close() - del self.fd_descriptor_cache[task.filename] + if not task.is_stream: + fileobj = self.fd_descriptor_cache.get(task.filename) + if fileobj is not None: + fileobj.close() + del self.fd_descriptor_cache[task.filename] def _cleanup(self): for fileobj in self.fd_descriptor_cache.values(): @@ -237,7 +241,7 @@ def __init__(self, result_queue, quiet): self._lock = threading.Lock() self._needs_newline = False - self._total_parts = 0 + self._total_parts = '...' self._total_files = '...' 
# This is a public attribute that clients can inspect to determine diff --git a/awscli/customizations/s3/filegenerator.py b/awscli/customizations/s3/filegenerator.py index b53be0c459397..f756bb4b3c624 100644 --- a/awscli/customizations/s3/filegenerator.py +++ b/awscli/customizations/s3/filegenerator.py @@ -95,7 +95,7 @@ def __init__(self, directory, filename): class FileStat(object): def __init__(self, src, dest=None, compare_key=None, size=None, last_update=None, src_type=None, dest_type=None, - operation_name=None): + operation_name=None, is_stream=False): self.src = src self.dest = dest self.compare_key = compare_key @@ -104,6 +104,7 @@ def __init__(self, src, dest=None, compare_key=None, size=None, self.src_type = src_type self.dest_type = dest_type self.operation_name = operation_name + self.is_stream = is_stream class FileGenerator(object): @@ -115,7 +116,8 @@ class FileGenerator(object): ``FileInfo`` objects to send to a ``Comparator`` or ``S3Handler``. """ def __init__(self, service, endpoint, operation_name, - follow_symlinks=True, page_size=None, result_queue=None): + follow_symlinks=True, page_size=None, result_queue=None, + is_stream=False): self._service = service self._endpoint = endpoint self.operation_name = operation_name @@ -124,6 +126,7 @@ def __init__(self, service, endpoint, operation_name, self.result_queue = result_queue if not result_queue: self.result_queue = queue.Queue() + self.is_stream = is_stream def call(self, files): """ @@ -135,7 +138,11 @@ def call(self, files): dest = files['dest'] src_type = src['type'] dest_type = dest['type'] - function_table = {'s3': self.list_objects, 'local': self.list_files} + function_table = {'s3': self.list_objects} + if self.is_stream: + function_table['local'] = self.list_local_file_stream + else: + function_table['local'] = self.list_files sep_table = {'s3': '/', 'local': os.sep} source = src['path'] file_list = function_table[src_type](source, files['dir_op']) @@ -155,7 +162,15 @@ def call(self, files): compare_key=compare_key, size=size, last_update=last_update, src_type=src_type, dest_type=dest_type, - operation_name=self.operation_name) + operation_name=self.operation_name, + is_stream=self.is_stream) + + def list_local_file_stream(self, path, dir_op): + """ + Yield some dummy values for a local file stream since it does not + actually have a file. + """ + yield '-', 0, None def list_files(self, path, dir_op): """ diff --git a/awscli/customizations/s3/fileinfo.py b/awscli/customizations/s3/fileinfo.py index fe482e64d13b2..407fcc8837ada 100644 --- a/awscli/customizations/s3/fileinfo.py +++ b/awscli/customizations/s3/fileinfo.py @@ -11,7 +11,7 @@ from botocore.compat import quote from awscli.customizations.s3.utils import find_bucket_key, \ check_etag, check_error, operate, uni_print, \ - guess_content_type, MD5Error + guess_content_type, MD5Error, bytes_print class CreateDirectoryError(Exception): @@ -26,7 +26,7 @@ def read_file(filename): return in_file.read() -def save_file(filename, response_data, last_update): +def save_file(filename, response_data, last_update, is_stream=False): """ This writes to the file upon downloading. It reads the data in the response. 
Makes a new directory if needed and then writes the @@ -35,31 +35,57 @@ def save_file(filename, response_data, last_update): """ body = response_data['Body'] etag = response_data['ETag'][1:-1] - d = os.path.dirname(filename) - try: - if not os.path.exists(d): - os.makedirs(d) - except OSError as e: - if not e.errno == errno.EEXIST: - raise CreateDirectoryError( - "Could not create directory %s: %s" % (d, e)) + if not is_stream: + d = os.path.dirname(filename) + try: + if not os.path.exists(d): + os.makedirs(d) + except OSError as e: + if not e.errno == errno.EEXIST: + raise CreateDirectoryError( + "Could not create directory %s: %s" % (d, e)) md5 = hashlib.md5() file_chunks = iter(partial(body.read, 1024 * 1024), b'') - with open(filename, 'wb') as out_file: - if not _is_multipart_etag(etag): - for chunk in file_chunks: - md5.update(chunk) - out_file.write(chunk) - else: - for chunk in file_chunks: - out_file.write(chunk) + if is_stream: + # Need to save the data to be able to check the etag for a stream + # because once the data is written to the stream there is no + # undoing it. + payload = write_to_file(None, etag, md5, file_chunks, True) + else: + with open(filename, 'wb') as out_file: + write_to_file(out_file, etag, md5, file_chunks) + if not _is_multipart_etag(etag): if etag != md5.hexdigest(): - os.remove(filename) + if not is_stream: + os.remove(filename) raise MD5Error(filename) - last_update_tuple = last_update.timetuple() - mod_timestamp = time.mktime(last_update_tuple) - os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + + if not is_stream: + last_update_tuple = last_update.timetuple() + mod_timestamp = time.mktime(last_update_tuple) + os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + else: + # Now write the output to stdout since the md5 is correct. + bytes_print(payload) + sys.stdout.flush() + + +def write_to_file(out_file, etag, md5, file_chunks, is_stream=False): + """ + Updates the md5 for each file chunk. It will write to the file if it is + a file, but if it is a stream it will return a byte string to be later + written to a stream. + """ + body = b'' + for chunk in file_chunks: + if not _is_multipart_etag(etag): + md5.update(chunk) + if is_stream: + body += chunk + else: + out_file.write(chunk) + return body def _is_multipart_etag(etag): @@ -140,7 +166,7 @@ class FileInfo(TaskInfo): def __init__(self, src, dest=None, compare_key=None, size=None, last_update=None, src_type=None, dest_type=None, operation_name=None, service=None, endpoint=None, - parameters=None, source_endpoint=None): + parameters=None, source_endpoint=None, is_stream=False): super(FileInfo, self).__init__(src, src_type=src_type, operation_name=operation_name, service=service, @@ -157,6 +183,7 @@ def __init__(self, src, dest=None, compare_key=None, size=None, self.parameters = {'acl': None, 'sse': None} self.source_endpoint = source_endpoint + self.is_stream = is_stream def _permission_to_param(self, permission): if permission == 'read': @@ -204,24 +231,30 @@ def _handle_object_params(self, params): if self.parameters['expires']: params['expires'] = self.parameters['expires'][0] - def upload(self): + def upload(self, payload=None): """ Redirects the file to the multipart upload function if the file is large. If it is small enough, it puts the file as an object in s3.
""" - with open(self.src, 'rb') as body: - bucket, key = find_bucket_key(self.dest) - params = { - 'endpoint': self.endpoint, - 'bucket': bucket, - 'key': key, - 'body': body, - } - self._handle_object_params(params) - response_data, http = operate(self.service, 'PutObject', params) - etag = response_data['ETag'][1:-1] - body.seek(0) - check_etag(etag, body) + if payload: + self._handle_upload(payload) + else: + with open(self.src, 'rb') as body: + self._handle_upload(body) + + def _handle_upload(self, body): + bucket, key = find_bucket_key(self.dest) + params = { + 'endpoint': self.endpoint, + 'bucket': bucket, + 'key': key, + 'body': body, + } + self._handle_object_params(params) + response_data, http = operate(self.service, 'PutObject', params) + etag = response_data['ETag'][1:-1] + body.seek(0) + check_etag(etag, body) def _inject_content_type(self, params, filename): # Add a content type param if we can guess the type. @@ -237,7 +270,8 @@ def download(self): bucket, key = find_bucket_key(self.src) params = {'endpoint': self.endpoint, 'bucket': bucket, 'key': key} response_data, http = operate(self.service, 'GetObject', params) - save_file(self.dest, response_data, self.last_update) + save_file(self.dest, response_data, self.last_update, + self.is_stream) def copy(self): """ diff --git a/awscli/customizations/s3/fileinfobuilder.py b/awscli/customizations/s3/fileinfobuilder.py index 8bc2042615ef8..b220565b61cc0 100644 --- a/awscli/customizations/s3/fileinfobuilder.py +++ b/awscli/customizations/s3/fileinfobuilder.py @@ -42,6 +42,7 @@ def _inject_info(self, file_base): file_info_attr['src_type'] = file_base.src_type file_info_attr['dest_type'] = file_base.dest_type file_info_attr['operation_name'] = file_base.operation_name + file_info_attr['is_stream'] = file_base.is_stream file_info_attr['service'] = self._service file_info_attr['endpoint'] = self._endpoint file_info_attr['source_endpoint'] = self._source_endpoint diff --git a/awscli/customizations/s3/s3handler.py b/awscli/customizations/s3/s3handler.py index 91f701bbd83d2..46dc2e6c897db 100644 --- a/awscli/customizations/s3/s3handler.py +++ b/awscli/customizations/s3/s3handler.py @@ -14,10 +14,13 @@ import logging import math import os +import six from six.moves import queue +import sys +import time from awscli.customizations.s3.constants import MULTI_THRESHOLD, CHUNKSIZE, \ - NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE + NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE, STREAM_INPUT_TIMEOUT from awscli.customizations.s3.utils import find_chunksize, \ operate, find_bucket_key, relative_path, PrintTask, create_warning from awscli.customizations.s3.executor import Executor @@ -53,16 +56,24 @@ def __init__(self, session, params, result_queue=None, 'content_type': None, 'cache_control': None, 'content_disposition': None, 'content_encoding': None, 'content_language': None, 'expires': None, - 'grants': None} + 'grants': None, 'is_stream': False, 'paths_type': None, + 'expected_size': None} self.params['region'] = params['region'] for key in self.params.keys(): if key in params: self.params[key] = params[key] self.multi_threshold = multi_threshold self.chunksize = chunksize + self._max_executer_queue_size = MAX_QUEUE_SIZE + if self.params['is_stream']: + # This ensures that at most the number of multipart chunks + # waiting in the executor queue from a stream read in from stdin + # is the same as the number of threads needed to upload it. 
+ self._max_executer_queue_size = NUM_THREADS self.executor = Executor( num_threads=NUM_THREADS, result_queue=self.result_queue, - quiet=self.params['quiet'], max_queue_size=MAX_QUEUE_SIZE, + quiet=self.params['quiet'], + max_queue_size=self._max_executer_queue_size, write_queue=self.write_queue ) self._multipart_uploads = [] @@ -162,7 +173,15 @@ def _enqueue_tasks(self, files): total_parts = 0 for filename in files: num_uploads = 1 - is_multipart_task = self._is_multipart_task(filename) + # If uploading stream, it is required to read from the stream + # to determine if the stream needs to be multipart uploaded. + payload = None + if getattr(filename, 'is_stream', False) and \ + filename.operation_name == 'upload': + payload, is_multipart_task = \ + self._pull_from_stream(self.multi_threshold) + else: + is_multipart_task = self._is_multipart_task(filename) too_large = False if hasattr(filename, 'size'): too_large = filename.size > MAX_UPLOAD_SIZE @@ -178,17 +197,42 @@ def _enqueue_tasks(self, files): # fact that it's transferring a file rather than # the specific part tasks required to perform the # transfer. - num_uploads = self._enqueue_multipart_tasks(filename) + num_uploads = self._enqueue_multipart_tasks(filename, payload) else: task = tasks.BasicTask( session=self.session, filename=filename, parameters=self.params, - result_queue=self.result_queue) + result_queue=self.result_queue, + payload=payload) self.executor.submit(task) total_files += 1 total_parts += num_uploads return total_files, total_parts + def _pull_from_stream(self, initial_amount_requested): + size = 0 + amount_requested = initial_amount_requested + total_retries = 0 + payload = b'' + stream_filein = sys.stdin + if six.PY3: + stream_filein = sys.stdin.buffer + while True: + payload_chunk = stream_filein.read(amount_requested) + payload_chunk_size = len(payload_chunk) + payload += payload_chunk + size += payload_chunk_size + amount_requested -= payload_chunk_size + if payload_chunk_size == 0: + time.sleep(STREAM_INPUT_TIMEOUT) + total_retries += 1 + else: + total_retries = 0 + if amount_requested == 0 or total_retries == 5: + break + payload_file = six.BytesIO(payload) + return payload_file, size == initial_amount_requested + def _is_multipart_task(self, filename): # First we need to determine if it's an operation that even # qualifies for multipart upload. 
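The ``_pull_from_stream`` helper above is the heart of streamed uploads: it reads at most ``amount_requested`` bytes from stdin, retrying on empty reads, and reports whether the full amount arrived, which tells the handler whether a multipart upload is needed. A rough standalone sketch of that behavior, assuming Python 3 (``sys.stdin.buffer``) and using a hypothetical ``pull_from_stream`` function that is not part of this patch::

    import sys
    import time
    from io import BytesIO

    STREAM_INPUT_TIMEOUT = 0.1   # same value as the constant added in constants.py
    MAX_EMPTY_READS = 5          # the patch gives up after 5 consecutive empty reads

    def pull_from_stream(amount_requested, stream=None):
        """Read up to ``amount_requested`` bytes from a binary stream.

        Returns ``(payload, filled)`` where ``filled`` is True when the full
        amount was read, meaning more data may still be waiting on the stream.
        """
        if stream is None:
            stream = sys.stdin.buffer   # binary stdin on Python 3
        payload = b''
        empty_reads = 0
        while len(payload) < amount_requested:
            chunk = stream.read(amount_requested - len(payload))
            if not chunk:
                empty_reads += 1
                if empty_reads == MAX_EMPTY_READS:
                    break
                time.sleep(STREAM_INPUT_TIMEOUT)
            else:
                empty_reads = 0
                payload += chunk
        return BytesIO(payload), len(payload) == amount_requested

The handler calls this once with ``multi_threshold`` bytes; if the stream fills that request, the upload is treated as multipart and subsequent parts are pulled with the chunk size (optionally sized via the new ``--expected-size`` argument).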
@@ -203,10 +247,11 @@ def _is_multipart_task(self, filename): else: return False - def _enqueue_multipart_tasks(self, filename): + def _enqueue_multipart_tasks(self, filename, payload=None): num_uploads = 1 if filename.operation_name == 'upload': - num_uploads = self._enqueue_multipart_upload_tasks(filename) + num_uploads = self._enqueue_multipart_upload_tasks(filename, + payload=payload) elif filename.operation_name == 'move': if filename.src_type == 'local' and filename.dest_type == 's3': num_uploads = self._enqueue_multipart_upload_tasks( @@ -231,9 +276,12 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): chunksize = find_chunksize(filename.size, self.chunksize) num_downloads = int(filename.size / chunksize) context = tasks.MultipartDownloadContext(num_downloads) - create_file_task = tasks.CreateLocalFileTask(context=context, - filename=filename) - self.executor.submit(create_file_task) + if not filename.is_stream: + create_file_task = tasks.CreateLocalFileTask(context=context, + filename=filename) + self.executor.submit(create_file_task) + else: + context.announce_file_created() for i in range(num_downloads): task = tasks.DownloadPartTask( part_number=i, chunk_size=chunksize, @@ -252,17 +300,27 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): return num_downloads def _enqueue_multipart_upload_tasks(self, filename, - remove_local_file=False): + remove_local_file=False, + payload=None): # First we need to create a CreateMultipartUpload task, # then create UploadTask objects for each of the parts. # And finally enqueue a CompleteMultipartUploadTask. - chunksize = find_chunksize(filename.size, self.chunksize) - num_uploads = int(math.ceil(filename.size / - float(chunksize))) + chunksize = self.chunksize + if not filename.is_stream: + chunksize = find_chunksize(filename.size, self.chunksize) + num_uploads = int(math.ceil(filename.size / + float(chunksize))) + else: + if self.params['expected_size']: + chunksize = find_chunksize(int(self.params['expected_size']), + self.chunksize) + num_uploads = '...' 
upload_context = self._enqueue_upload_start_task( - chunksize, num_uploads, filename) - self._enqueue_upload_tasks( - num_uploads, chunksize, upload_context, filename, tasks.UploadPartTask) + chunksize, num_uploads, filename, payload) + num_uploads = self._enqueue_upload_tasks( + num_uploads, chunksize, upload_context, + filename, tasks.UploadPartTask + ) self._enqueue_upload_end_task(filename, upload_context) if remove_local_file: remove_task = tasks.RemoveFileTask(local_filename=filename.src, @@ -276,8 +334,7 @@ def _enqueue_multipart_copy_tasks(self, filename, num_uploads = int(math.ceil(filename.size / float(chunksize))) upload_context = self._enqueue_upload_start_task( chunksize, num_uploads, filename) - self._enqueue_upload_tasks( - num_uploads, chunksize, upload_context, filename, tasks.CopyPartTask) + self._enqueue_upload_tasks(num_uploads, chunksize, upload_context, filename, tasks.CopyPartTask) self._enqueue_upload_end_task(filename, upload_context) if remove_remote_file: remove_task = tasks.RemoveRemoteObjectTask( @@ -285,7 +342,8 @@ def _enqueue_multipart_copy_tasks(self, filename, self.executor.submit(remove_task) return num_uploads - def _enqueue_upload_start_task(self, chunksize, num_uploads, filename): + def _enqueue_upload_start_task(self, chunksize, num_uploads, filename, + payload=None): upload_context = tasks.MultipartUploadContext( expected_parts=num_uploads) create_multipart_upload_task = tasks.CreateMultipartUploadTask( @@ -293,16 +351,56 @@ def _enqueue_upload_start_task(self, chunksize, num_uploads, filename): parameters=self.params, result_queue=self.result_queue, upload_context=upload_context) self.executor.submit(create_multipart_upload_task) + if filename.is_stream and filename.operation_name == 'upload': + # Upload the part that was initially pulled from the stream. + self._enqueue_upload_single_part_task( + part_number=1, chunk_size=chunksize, + upload_context=upload_context, filename=filename, + task_class=tasks.UploadPartTask, payload=payload + ) return upload_context - def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, filename, - task_class): - for i in range(1, (num_uploads + 1)): - task = task_class( - part_number=i, chunk_size=chunksize, - result_queue=self.result_queue, upload_context=upload_context, - filename=filename) - self.executor.submit(task) + def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, + filename, task_class): + if filename.is_stream and filename.operation_name == 'upload': + # The previous upload occurred right after the multipart + # upload started for a stream.
+ num_uploads = 1 + while True: + payload, is_remaining = self._pull_from_stream(chunksize) + self._enqueue_upload_single_part_task( + part_number=num_uploads+1, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class, + payload=payload + ) + num_uploads += 1 + if not is_remaining: + break + upload_context.announce_total_parts(num_uploads) + else: + for i in range(1, (num_uploads + 1)): + self._enqueue_upload_single_part_task( + part_number=i, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class + ) + return num_uploads + + def _enqueue_upload_single_part_task(self, part_number, chunk_size, + upload_context, filename, task_class, + payload=None): + kwargs = {'part_number': part_number, 'chunk_size': chunk_size, + 'result_queue': self.result_queue, + 'upload_context': upload_context, 'filename': filename} + if payload: + kwargs['payload'] = payload + task = task_class(**kwargs) + self.executor.submit(task) def _enqueue_upload_end_task(self, filename, upload_context): complete_multipart_upload_task = tasks.CompleteMultipartUploadTask( diff --git a/awscli/customizations/s3/subcommands.py b/awscli/customizations/s3/subcommands.py index 6ec6856294d9b..0df2f8c1e94fd 100644 --- a/awscli/customizations/s3/subcommands.py +++ b/awscli/customizations/s3/subcommands.py @@ -206,6 +206,15 @@ 'The object key name to use when ' 'a 4XX class error occurs.')} +EXPECTED_SIZE = {'name': 'expected-size', + 'help_text': ( + 'This argument specifies the expected size of a stream ' + 'in terms of bytes. Note that this argument is needed ' + 'only when a stream is being uploaded to s3 and the size ' + 'is larger than 5GB. Failure to include this argument ' + 'under these conditions may result in a failed upload. ' + 'due to too many parts in upload.')} + TRANSFER_ARGS = [DRYRUN, QUIET, RECURSIVE, INCLUDE, EXCLUDE, ACL, FOLLOW_SYMLINKS, NO_FOLLOW_SYMLINKS, NO_GUESS_MIME_TYPE, SSE, STORAGE_CLASS, GRANTS, WEBSITE_REDIRECT, CONTENT_TYPE, @@ -413,7 +422,7 @@ class CpCommand(S3TransferCommand): USAGE = " or " \ "or " ARG_TABLE = [{'name': 'paths', 'nargs': 2, 'positional_arg': True, - 'synopsis': USAGE}] + TRANSFER_ARGS + 'synopsis': USAGE}] + TRANSFER_ARGS + [EXPECTED_SIZE] EXAMPLES = BasicCommand.FROM_FILE('s3/cp.rst') @@ -566,7 +575,8 @@ def run(self): operation_name, self.parameters['follow_symlinks'], self.parameters['page_size'], - result_queue=result_queue) + result_queue=result_queue, + is_stream=self.parameters['is_stream']) rev_generator = FileGenerator(self._service, self._endpoint, '', self.parameters['follow_symlinks'], self.parameters['page_size'], @@ -683,8 +693,19 @@ def add_paths(self, paths): self.parameters['dest'] = paths[1] elif len(paths) == 1: self.parameters['dest'] = paths[0] + self._validate_streaming_paths() self._validate_path_args() + def _validate_streaming_paths(self): + self.parameters['is_stream'] = False + if self.parameters['src'] == '-' or self.parameters['dest'] == '-': + self.parameters['is_stream'] = True + self.parameters['dir_op'] = False + self.parameters['quiet'] = True + if self.parameters['is_stream'] and self.cmd != 'cp': + raise ValueError("Streaming currently is only compatible with " + "single file cp commands") + def _validate_path_args(self): # If we're using a mv command, you can't copy the object onto itself. 
params = self.parameters diff --git a/awscli/customizations/s3/tasks.py b/awscli/customizations/s3/tasks.py index 37089c42a2b99..07e6547bb619a 100644 --- a/awscli/customizations/s3/tasks.py +++ b/awscli/customizations/s3/tasks.py @@ -63,7 +63,8 @@ class BasicTask(OrderableTask): attributes like ``session`` object in order for the filename to perform its designated operation. """ - def __init__(self, session, filename, parameters, result_queue): + def __init__(self, session, filename, parameters, + result_queue, payload=None): self.session = session self.service = self.session.get_service('s3') @@ -72,6 +73,7 @@ def __init__(self, session, filename, parameters, result_queue): self.parameters = parameters self.result_queue = result_queue + self.payload = payload def __call__(self): self._execute_task(attempts=3) @@ -84,9 +86,12 @@ def _execute_task(self, attempts, last_error=''): error_message=last_error) return filename = self.filename + kwargs = {} + if self.payload: + kwargs['payload'] = self.payload try: if not self.parameters['dryrun']: - getattr(filename, filename.operation_name)() + getattr(filename, filename.operation_name)(**kwargs) except requests.ConnectionError as e: connect_error = str(e) LOGGER.debug("%s %s failure: %s", @@ -195,13 +200,14 @@ class UploadPartTask(OrderableTask): complete the multipart upload initiated by the ``FileInfo`` object. """ - def __init__(self, part_number, chunk_size, - result_queue, upload_context, filename): + def __init__(self, part_number, chunk_size, result_queue, upload_context, + filename, payload=None): self._result_queue = result_queue self._upload_context = upload_context self._part_number = part_number self._chunk_size = chunk_size self._filename = filename + self._payload = payload def _read_part(self): actual_filename = self._filename.src @@ -216,9 +222,13 @@ def __call__(self): LOGGER.debug("Waiting for upload id.") upload_id = self._upload_context.wait_for_upload_id() bucket, key = find_bucket_key(self._filename.dest) - total = int(math.ceil( - self._filename.size/float(self._chunk_size))) - body = self._read_part() + if self._filename.is_stream: + body = self._payload + total = self._upload_context.expected_parts + else: + total = int(math.ceil( + self._filename.size/float(self._chunk_size))) + body = self._read_part() params = {'endpoint': self._filename.endpoint, 'bucket': bucket, 'key': key, 'part_number': self._part_number, @@ -298,14 +308,17 @@ def __call__(self): # 3) Queue an IO request to the IO thread letting it know we're # done with the file. 
self._context.wait_for_completion() - last_update_tuple = self._filename.last_update.timetuple() - mod_timestamp = time.mktime(last_update_tuple) - os.utime(self._filename.dest, (int(mod_timestamp), int(mod_timestamp))) + if not self._filename.is_stream: + last_update_tuple = self._filename.last_update.timetuple() + mod_timestamp = time.mktime(last_update_tuple) + os.utime(self._filename.dest, + (int(mod_timestamp), int(mod_timestamp))) message = print_operation(self._filename, False, self._parameters['dryrun']) print_task = {'message': message, 'error': False} self._result_queue.put(PrintTask(**print_task)) - self._io_queue.put(IOCloseRequest(self._filename.dest)) + self._io_queue.put(IOCloseRequest(self._filename.dest, + self._filename.is_stream)) class DownloadPartTask(OrderableTask): @@ -393,16 +406,23 @@ def _queue_writes(self, body): body.set_socket_timeout(self.READ_TIMEOUT) amount_read = 0 current = body.read(iterate_chunk_size) + if self._filename.is_stream: + self._context.wait_for_turn(self._part_number) while current: offset = self._part_number * self._chunk_size + amount_read LOGGER.debug("Submitting IORequest to write queue.") - self._io_queue.put(IORequest(self._filename.dest, offset, current)) + self._io_queue.put( + IORequest(self._filename.dest, offset, current, + self._filename.is_stream) + ) LOGGER.debug("Request successfully submitted.") amount_read += len(current) current = body.read(iterate_chunk_size) # Change log message. LOGGER.debug("Done queueing writes for part number %s to file: %s", self._part_number, self._filename.dest) + if self._filename.is_stream: + self._context.done_with_turn() class CreateMultipartUploadTask(BasicTask): @@ -530,7 +550,7 @@ class MultipartUploadContext(object): _CANCELLED = '_CANCELLED' _COMPLETED = '_COMPLETED' - def __init__(self, expected_parts): + def __init__(self, expected_parts='...'): self._upload_id = None self._expected_parts = expected_parts self._parts = [] @@ -540,6 +560,10 @@ def __init__(self, expected_parts): self._upload_complete_condition = threading.Condition(self._lock) self._state = self._UNSTARTED + @property + def expected_parts(self): + return self._expected_parts + def announce_upload_id(self, upload_id): with self._upload_id_condition: self._upload_id = upload_id @@ -551,9 +575,15 @@ def announce_finished_part(self, etag, part_number): self._parts.append({'ETag': etag, 'PartNumber': part_number}) self._parts_condition.notifyAll() + def announce_total_parts(self, total_parts): + with self._parts_condition: + self._expected_parts = total_parts + self._parts_condition.notifyAll() + def wait_for_parts_to_finish(self): with self._parts_condition: - while len(self._parts) < self._expected_parts: + while self._expected_parts == '...' 
or \ + len(self._parts) < self._expected_parts: if self._state == self._CANCELLED: raise UploadCancelledError("Upload has been cancelled.") self._parts_condition.wait(timeout=1) @@ -653,9 +683,11 @@ def __init__(self, num_parts, lock=None): lock = threading.Lock() self._lock = lock self._created_condition = threading.Condition(self._lock) + self._submit_write_condition = threading.Condition(self._lock) self._completed_condition = threading.Condition(self._lock) self._state = self._STATES['UNSTARTED'] self._finished_parts = set() + self._current_stream_part_number = 0 def announce_completed_part(self, part_number): with self._completed_condition: @@ -685,6 +717,19 @@ def wait_for_completion(self): "Download has been cancelled.") self._completed_condition.wait(timeout=1) + def wait_for_turn(self, part_number): + with self._submit_write_condition: + while self._current_stream_part_number != part_number: + if self._state == self._STATES['CANCELLED']: + raise DownloadCancelledError( + "Download has been cancelled.") + self._submit_write_condition.wait(timeout=0.2) + + def done_with_turn(self): + with self._submit_write_condition: + self._current_stream_part_number += 1 + self._submit_write_condition.notifyAll() + def cancel(self): with self._lock: self._state = self._STATES['CANCELLED'] diff --git a/awscli/customizations/s3/utils.py b/awscli/customizations/s3/utils.py index eea51a5fbdbc3..76cf68b39bd68 100644 --- a/awscli/customizations/s3/utils.py +++ b/awscli/customizations/s3/utils.py @@ -243,6 +243,21 @@ def uni_print(statement): sys.stdout.write(statement.encode('utf-8')) +def bytes_print(statement): + """ + This function is used to properly write bytes to standard out. + """ + if PY3: + if getattr(sys.stdout, 'buffer', None): + sys.stdout.buffer.write(statement) + else: + # If it is not possible to write to the standard out buffer. + # The next best option is to decode and write to standard out. + sys.stdout.write(statement.decode('utf-8')) + else: + sys.stdout.write(statement) + + def guess_content_type(filename): """Given a filename, guess it's content type. @@ -396,7 +411,8 @@ def __new__(cls, message, error=False, total_parts=None, warning=None): warning) -IORequest = namedtuple('IORequest', ['filename', 'offset', 'data']) +IORequest = namedtuple('IORequest', + ['filename', 'offset', 'data', 'is_stream']) # Used to signal that IO for the filename is finished, and that # any associated resources may be cleaned up. 
-IOCloseRequest = namedtuple('IOCloseRequest', ['filename']) +IOCloseRequest = namedtuple('IOCloseRequest', ['filename', 'is_stream']) diff --git a/awscli/examples/s3/cp.rst b/awscli/examples/s3/cp.rst index 6bdf25dc0dc12..1fe488bc77519 100644 --- a/awscli/examples/s3/cp.rst +++ b/awscli/examples/s3/cp.rst @@ -101,3 +101,15 @@ Output:: upload: file.txt to s3://mybucket/file.txt +**Uploading a local file stream to S3** + +The following ``cp`` command uploads a local file stream from standard input to a specified bucket and key:: + + aws s3 cp - s3://mybucket/stream.txt + + +**Downloading a S3 object as a local file stream** + +The following ``cp`` command downloads a S3 object locally as a stream to standard output:: + + aws s3 cp s3://mybucket/stream.txt - diff --git a/awscli/testutils.py b/awscli/testutils.py index b6f4bd2abcc01..2f4b55e018c1b 100644 --- a/awscli/testutils.py +++ b/awscli/testutils.py @@ -395,7 +395,7 @@ def _escape_quotes(command): def aws(command, collect_memory=False, env_vars=None, - wait_for_finish=True): + wait_for_finish=True, input_data=None): """Run an aws command. This help function abstracts the differences of running the "aws" @@ -421,7 +421,7 @@ def aws(command, collect_memory=False, env_vars=None, else: aws_command = 'python %s' % get_aws_cmd() full_command = '%s %s' % (aws_command, command) - stdout_encoding = _get_stdout_encoding() + stdout_encoding = get_stdout_encoding() if isinstance(full_command, six.text_type) and not six.PY3: full_command = full_command.encode(stdout_encoding) INTEG_LOG.debug("Running command: %s", full_command) @@ -429,13 +429,16 @@ def aws(command, collect_memory=False, env_vars=None, env['AWS_DEFAULT_REGION'] = "us-east-1" if env_vars is not None: env = env_vars - process = Popen(full_command, stdout=PIPE, stderr=PIPE, shell=True, - env=env) + process = Popen(full_command, stdout=PIPE, stderr=PIPE, stdin=PIPE, + shell=True, env=env) if not wait_for_finish: return process memory = None if not collect_memory: - stdout, stderr = process.communicate() + kwargs = {} + if input_data: + kwargs = {'input': input_data} + stdout, stderr = process.communicate(**kwargs) else: stdout, stderr, memory = _wait_and_collect_mem(process) return Result(process.returncode, @@ -444,7 +447,7 @@ def aws(command, collect_memory=False, env_vars=None, memory) -def _get_stdout_encoding(): +def get_stdout_encoding(): encoding = getattr(sys.__stdout__, 'encoding', None) if encoding is None: encoding = 'utf-8' diff --git a/tests/integration/customizations/s3/test_plugin.py b/tests/integration/customizations/s3/test_plugin.py index 65e9d0f422177..93882b7f5598a 100644 --- a/tests/integration/customizations/s3/test_plugin.py +++ b/tests/integration/customizations/s3/test_plugin.py @@ -28,7 +28,7 @@ import botocore.session import six -from awscli.testutils import unittest, FileCreator +from awscli.testutils import unittest, FileCreator, get_stdout_encoding from awscli.testutils import aws as _aws from tests.unit.customizations.s3 import create_bucket as _create_bucket from awscli.customizations.s3 import constants @@ -44,12 +44,13 @@ def cd(directory): os.chdir(original) -def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True): +def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True, + input_data=None): if not env_vars: env_vars = os.environ.copy() env_vars['AWS_DEFAULT_REGION'] = "us-west-2" return _aws(command, collect_memory=collect_memory, env_vars=env_vars, - wait_for_finish=wait_for_finish) + 
wait_for_finish=wait_for_finish, input_data=input_data) class BaseS3CLICommand(unittest.TestCase): @@ -1222,5 +1223,99 @@ def test_sync_file_with_spaces(self): self.assertEqual(p2.rc, 0) +class TestStreams(BaseS3CLICommand): + def test_upload(self): + """ + This tests uploading a small stream from stdin. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + 'This is a test') + + def test_unicode_upload(self): + """ + This tests being able to upload unicode from stdin. + """ + unicode_str = u'\u00e9 This is a test' + byte_str = unicode_str.encode('utf-8') + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=byte_str) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + unicode_str) + + def test_multipart_upload(self): + """ + This tests the ability to multipart upload streams from stdin. + The data has some unicode in it to avoid having to do a separate + multipart upload test just for unicode. + """ + + bucket_name = self.create_bucket() + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assert_key_contents_equal(bucket_name, 'stream', data) + + def test_download(self): + """ + This tests downloading a small stream to stdout. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, 'This is a test') + + def test_unicode_download(self): + """ + This tests downloading a small unicode stream to stdout. + """ + bucket_name = self.create_bucket() + + data = u'\u00e9 This is a test' + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + + # Downloading the unicode stream to standard out. + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + def test_multipart_download(self): + """ + This tests the ability to multipart download streams to stdout. + The data has some unicode in it to avoid having to do a separate + multipart download test just for unicode. + """ + bucket_name = self.create_bucket() + + # First let's upload some data via streaming since + # it's faster and we do not have to write to a file! + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + + # Download the unicode stream to standard out.
+ p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/s3/__init__.py b/tests/unit/customizations/s3/__init__.py index 5ada082edd3e9..f8d8aa9ad916a 100644 --- a/tests/unit/customizations/s3/__init__.py +++ b/tests/unit/customizations/s3/__init__.py @@ -16,7 +16,7 @@ import string import six -from mock import patch +from mock import patch, Mock class S3HandlerBaseTest(unittest.TestCase): @@ -33,7 +33,6 @@ def setUp(self): def tearDown(self): self.wait_timeout_patch.stop() - def make_loc_files(): """ This sets up the test by making a directory named some_directory. It @@ -161,6 +160,7 @@ def compare_files(self, result_file, ref_file): self.assertEqual(result_file.src_type, ref_file.src_type) self.assertEqual(result_file.dest_type, ref_file.dest_type) self.assertEqual(result_file.operation_name, ref_file.operation_name) + self.assertEqual(result_file.is_stream, ref_file.is_stream) def list_contents(bucket, session): @@ -188,3 +188,24 @@ def list_buckets(session): html_response, response_data = operation.call(endpoint) contents = response_data['Buckets'] return contents + + +class MockStdIn(object): + """ + This class patches stdin in order to write a stream of bytes into + stdin. + """ + def __init__(self, input_bytes=b''): + input_data = six.BytesIO(input_bytes) + if six.PY2: + mock_object = input_data + else: + mock_object = Mock() + mock_object.buffer = input_data + self._patch = patch('sys.stdin', mock_object) + + def __enter__(self): + self._patch.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + self._patch.__exit__() diff --git a/tests/unit/customizations/s3/test_executor.py b/tests/unit/customizations/s3/test_executor.py index 9afaacd3ba22c..46eecbae6b638 100644 --- a/tests/unit/customizations/s3/test_executor.py +++ b/tests/unit/customizations/s3/test_executor.py @@ -15,6 +15,7 @@ import shutil import six from six.moves import queue +import sys import mock @@ -41,17 +42,17 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_handles_io_request(self): - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IOCloseRequest(self.filename)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IOCloseRequest(self.filename, False)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() with open(self.filename, 'rb') as f: self.assertEqual(f.read(), b'foobar') def test_out_of_order_io_requests(self): - self.queue.put(IORequest(self.filename, 6, b'morestuff')) - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IOCloseRequest(self.filename)) + self.queue.put(IORequest(self.filename, 6, b'morestuff', False)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IOCloseRequest(self.filename, False)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() with open(self.filename, 'rb') as f: @@ -60,10 +61,10 @@ def test_out_of_order_io_requests(self): def test_multiple_files_in_queue(self): second_file = os.path.join(self.temp_dir, 'bar') open(second_file, 'w').close() - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IORequest(second_file, 0, b'otherstuff')) - self.queue.put(IOCloseRequest(second_file)) - self.queue.put(IOCloseRequest(self.filename)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IORequest(second_file, 0, 
b'otherstuff', False)) + self.queue.put(IOCloseRequest(second_file, False)) + self.queue.put(IOCloseRequest(self.filename, False)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() @@ -72,6 +73,21 @@ def test_multiple_files_in_queue(self): with open(second_file, 'rb') as f: self.assertEqual(f.read(), b'otherstuff') + def test_stream_requests(self): + # Test that offset has no effect on the order in which requests + # are written to stdout. The order of requests for a stream is + # first in, first out. + self.queue.put(IORequest('nonexistent-file', 10, b'foobar', True)) + self.queue.put(IORequest('nonexistent-file', 6, b'otherstuff', True)) + # The thread should not try to close the file because it is + # writing to stdout. If it does, the thread will fail because + # the file does not exist. + self.queue.put(IOCloseRequest('nonexistent-file', True)) + self.queue.put(ShutdownThreadRequest()) + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.io_thread.run() + self.assertEqual(mock_stdout.getvalue(), 'foobarotherstuff') + class TestExecutor(unittest.TestCase): def test_shutdown_does_not_hang(self): @@ -84,12 +100,14 @@ class FloodIOQueueTask(object): def __call__(self): for i in range(50): - executor.write_queue.put(IORequest(f.name, 0, b'foobar')) + executor.write_queue.put(IORequest(f.name, 0, + b'foobar', False)) executor.submit(FloodIOQueueTask()) executor.initiate_shutdown() executor.wait_until_shutdown() self.assertEqual(open(f.name, 'rb').read(), b'foobar') + class TestPrintThread(unittest.TestCase): def test_print_warning(self): result_queue = queue.Queue() diff --git a/tests/unit/customizations/s3/test_filegenerator.py b/tests/unit/customizations/s3/test_filegenerator.py index d38a48424ce26..3c199c8589541 100644 --- a/tests/unit/customizations/s3/test_filegenerator.py +++ b/tests/unit/customizations/s3/test_filegenerator.py @@ -486,6 +486,24 @@ def test_normalize_sort_backslash(self): self.assertEqual(ref_names[i], names[i]) +class TestLocalStreams(unittest.TestCase): + def test_local_stream(self): + file_input = {'src': {'path': '-', 'type': 'local'}, + 'dest': {'path': 'mybucket/', 'type': 's3'}, + 'dir_op': False, 'use_src_name': True} + file_generator = FileGenerator(None, None, None, is_stream=True) + files = file_generator.call(file_input) + result_list = [] + for file_stat in files: + result_list.append(file_stat) + ref_list = [FileStat(src='-', dest='mybucket/-', compare_key='-', + size=0, last_update=None, src_type='local', + dest_type='s3', operation_name=None, + is_stream=True)] + for i in range(len(result_list)): + compare_files(self, result_list[i], ref_list[i]) + + class S3FileGeneratorTest(unittest.TestCase): def setUp(self): self.session = FakeSession() diff --git a/tests/unit/customizations/s3/test_fileinfo.py b/tests/unit/customizations/s3/test_fileinfo.py index 48a6651f42fbc..6a31e3edb1d59 100644 --- a/tests/unit/customizations/s3/test_fileinfo.py +++ b/tests/unit/customizations/s3/test_fileinfo.py @@ -21,6 +21,7 @@ from awscli.testutils import unittest from awscli.customizations.s3 import fileinfo +from awscli.customizations.s3.utils import MD5Error class TestSaveFile(unittest.TestCase): @@ -58,3 +59,16 @@ def test_makedir_other_exception(self, makedirs): fileinfo.save_file(self.filename, self.response_data, self.last_update) self.assertFalse(os.path.isfile(self.filename)) + + def test_stream_file(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + fileinfo.save_file(None, self.response_data, None,
True) + self.assertEqual(mock_stdout.getvalue(), "foobar") + + def test_stream_file_md5_error(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.response_data['ETag'] = '"0"' + with self.assertRaises(MD5Error): + fileinfo.save_file(None, self.response_data, None, True) + # Make sure nothing is written to stdout. + self.assertEqual(mock_stdout.getvalue(), "") diff --git a/tests/unit/customizations/s3/test_fileinfobuilder.py b/tests/unit/customizations/s3/test_fileinfobuilder.py index 439c006ad1368..1791fd93f5717 100644 --- a/tests/unit/customizations/s3/test_fileinfobuilder.py +++ b/tests/unit/customizations/s3/test_fileinfobuilder.py @@ -26,7 +26,8 @@ def test_info_setter(self): files = [FileStat(src='src', dest='dest', compare_key='compare_key', size='size', last_update='last_update', src_type='src_type', dest_type='dest_type', - operation_name='operation_name')] + operation_name='operation_name', + is_stream='is_stream')] file_infos = info_setter.call(files) for file_info in file_infos: attributes = file_info.__dict__.keys() diff --git a/tests/unit/customizations/s3/test_s3handler.py b/tests/unit/customizations/s3/test_s3handler.py index 20bc3a62a8589..6bf368d45bbbf 100644 --- a/tests/unit/customizations/s3/test_s3handler.py +++ b/tests/unit/customizations/s3/test_s3handler.py @@ -14,15 +14,19 @@ import os import random import sys -from awscli.testutils import unittest +import mock + +from awscli.testutils import unittest from awscli import EnvironmentVariables from awscli.customizations.s3.s3handler import S3Handler from awscli.customizations.s3.fileinfo import FileInfo +from awscli.customizations.s3.tasks import CreateMultipartUploadTask, \ + UploadPartTask, CreateLocalFileTask from tests.unit.customizations.s3.fake_session import FakeSession from tests.unit.customizations.s3 import make_loc_files, clean_loc_files, \ make_s3_files, s3_cleanup, create_bucket, list_contents, list_buckets, \ - S3HandlerBaseTest + S3HandlerBaseTest, MockStdIn class S3HandlerTestDeleteList(S3HandlerBaseTest): @@ -612,5 +616,164 @@ def test_bucket(self): self.assertEqual(orig_number_buckets, number_buckets) +class TestStreams(S3HandlerBaseTest): + def setUp(self): + super(TestStreams, self).setUp() + self.session = FakeSession() + self.service = self.session.get_service('s3') + self.endpoint = self.service.get_endpoint('us-east-1') + self.params = {'is_stream': True, 'region': 'us-east-1'} + stream_timeout = 'awscli.customizations.s3.constants.STREAM_INPUT_TIMEOUT' + self.stream_timeout_patch = mock.patch(stream_timeout, 0.001) + self.stream_timeout_patch.start() + + def tearDown(self): + super(TestStreams, self).tearDown() + self.stream_timeout_patch.stop() + + def test_pull_from_stream(self): + s3handler = S3Handler(self.session, self.params, chunksize=2) + input_to_stdin = b'This is a test' + size = len(input_to_stdin) + # Retrieve the entire string. + with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin) + # Ensure the function exits when there is nothing to read. + with MockStdIn(): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, b'') + # Ensure the function does not grab too much out of stdin. 
+ with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size-2) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin[:-2]) + # Retrieve the rest of standard in. + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, input_to_stdin[-2:]) + + def test_upload_stream_not_multipart_task(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # No multipart upload should have been submitted. + self.assertEqual(len(submitted_tasks), 1) + self.assertEqual(submitted_tasks[0][0][0].payload.read(), + b'bar') + + def test_upload_stream_is_multipart_task(self): + s3handler = S3Handler(self.session, self.params, + multi_threshold=1) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # This should be a multipart upload so multiple tasks + # should have been submitted. + self.assertEqual(len(submitted_tasks), 4) + self.assertEqual(submitted_tasks[1][0][0]._payload.read(), + b'b') + self.assertEqual(submitted_tasks[2][0][0]._payload.read(), + b'ar') + + def test_upload_stream_with_expected_size(self): + self.params['expected_size'] = 100000 + # With this large of expected size, the chunksize of 2 will have + # to change. + s3handler = S3Handler(self.session, self.params, chunksize=2) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + with MockStdIn(b'bar'): + s3handler._enqueue_multipart_upload_tasks(fileinfo, False, b'') + submitted_tasks = s3handler.executor.submit.call_args_list + # Determine what the chunksize was changed to from one of the + # UploadPartTasks. + changed_chunk_size = submitted_tasks[1][0][0]._chunk_size + # New chunksize should have a total parts under 1000. + self.assertTrue(100000/changed_chunk_size < 1000) + + def test_upload_stream_enqueue_upload_start_task(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + s3handler._enqueue_upload_start_task(None, None, fileinfo, b'foo') + submitted_tasks = s3handler.executor.submit.call_args_list + self.assertEqual(len(submitted_tasks), 2) + self.assertEqual(type(submitted_tasks[0][0][0]), + CreateMultipartUploadTask) + # Check that the initially pulled part of the stream gets submitted + # after the instantiating the CreateMultipartTask. 
+ self.assertEqual(type(submitted_tasks[1][0][0]), + UploadPartTask) + # Check that the payload is correct + self.assertEqual(submitted_tasks[1][0][0]._payload, b'foo') + + def test_upload_stream_enqueue_upload_task(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + stdin_input = b'This is a test' + with MockStdIn(stdin_input): + num_parts = s3handler._enqueue_upload_tasks(None, 2, mock.Mock(), + fileinfo, + UploadPartTask) + submitted_tasks = s3handler.executor.submit.call_args_list + # Ensure the returned number of parts is correct. + self.assertEqual(num_parts, len(submitted_tasks) + 1) + # Ensure the number of tasks uploaded are as expected + self.assertEqual(len(submitted_tasks), 8) + index = 0 + for i in range(len(submitted_tasks)-1): + self.assertEqual(submitted_tasks[i][0][0]._payload.read(), + stdin_input[index:index+2]) + index += 2 + # Ensure that the last part is an empty string as expected. + self.assertEqual(submitted_tasks[7][0][0]._payload.read(), b'') + + def test_enqueue_upload_single_part_task_stream(self): + """ + This test ensures that a payload gets attached to a task when + it is submitted to the executor. + """ + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + mock_task_class = mock.Mock() + s3handler._enqueue_upload_single_part_task( + part_number=1, chunk_size=2, upload_context=None, + filename=None, task_class=mock_task_class, + payload=b'This is a test' + ) + args, kwargs = mock_task_class.call_args + self.assertIn('payload', kwargs.keys()) + self.assertEqual(kwargs['payload'], b'This is a test') + + def test_enqueue_range_download_tasks_stream(self): + s3handler = S3Handler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='download', + is_stream=True, size=100) + s3handler._enqueue_range_download_tasks(fileinfo) + # Ensure that no request was sent to make a file locally. 
+        submitted_tasks = s3handler.executor.submit.call_args_list
+        self.assertNotEqual(type(submitted_tasks[0][0][0]),
+                            CreateLocalFileTask)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/unit/customizations/s3/test_subcommands.py b/tests/unit/customizations/s3/test_subcommands.py
index 363fc496889ff..92b4e8b1e18ff 100644
--- a/tests/unit/customizations/s3/test_subcommands.py
+++ b/tests/unit/customizations/s3/test_subcommands.py
@@ -197,7 +197,8 @@ def test_run_cp_put(self):
                   'src': local_file, 'dest': s3_file, 'filters': filters,
                   'paths_type': 'locals3', 'region': 'us-east-1',
                   'endpoint_url': None, 'verify_ssl': None,
-                  'follow_symlinks': True, 'page_size': None}
+                  'follow_symlinks': True, 'page_size': None,
+                  'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'cp', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -213,7 +214,8 @@ def test_error_on_same_line_as_status(self):
                   'src': local_file, 'dest': s3_file, 'filters': filters,
                   'paths_type': 'locals3', 'region': 'us-east-1',
                   'endpoint_url': None, 'verify_ssl': None,
-                  'follow_symlinks': True, 'page_size': None}
+                  'follow_symlinks': True, 'page_size': None,
+                  'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'cp', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -236,7 +238,8 @@ def test_run_cp_get(self):
                   'src': s3_file, 'dest': local_file, 'filters': filters,
                   'paths_type': 's3local', 'region': 'us-east-1',
                   'endpoint_url': None, 'verify_ssl': None,
-                  'follow_symlinks': True, 'page_size': None}
+                  'follow_symlinks': True, 'page_size': None,
+                  'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'cp', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -253,7 +256,8 @@ def test_run_cp_copy(self):
                   'src': s3_file, 'dest': s3_file, 'filters': filters,
                   'paths_type': 's3s3', 'region': 'us-east-1',
                   'endpoint_url': None, 'verify_ssl': None,
-                  'follow_symlinks': True, 'page_size': None}
+                  'follow_symlinks': True, 'page_size': None,
+                  'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'cp', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -270,7 +274,8 @@ def test_run_mv(self):
                   'src': s3_file, 'dest': s3_file, 'filters': filters,
                   'paths_type': 's3s3', 'region': 'us-east-1',
                   'endpoint_url': None, 'verify_ssl': None,
-                  'follow_symlinks': True, 'page_size': None}
+                  'follow_symlinks': True, 'page_size': None,
+                  'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'mv', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -287,7 +292,8 @@ def test_run_remove(self):
                   'src': s3_file, 'dest': s3_file, 'filters': filters,
                   'paths_type': 's3', 'region': 'us-east-1',
                   'endpoint_url': None, 'verify_ssl': None,
-                  'follow_symlinks': True, 'page_size': None}
+                  'follow_symlinks': True, 'page_size': None,
+                  'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'rm', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -308,7 +314,8 @@ def test_run_sync(self):
                   'src': local_dir, 'dest': s3_prefix, 'filters': filters,
                   'paths_type': 'locals3', 'region': 'us-east-1',
                   'endpoint_url': None, 'verify_ssl': None,
-                  'follow_symlinks': True, 'page_size': None}
+                  'follow_symlinks': True, 'page_size': None,
+                  'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'sync', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -324,7 +331,7 @@ def test_run_mb(self):
                   'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3',
                   'region': 'us-east-1', 'endpoint_url': None,
                   'verify_ssl': None, 'follow_symlinks': True,
-                  'page_size': None}
+                  'page_size': None, 'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'mb', params)
         cmd_arc.create_instructions()
         cmd_arc.run()
@@ -340,7 +347,7 @@ def test_run_rb(self):
                   'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3',
                   'region': 'us-east-1', 'endpoint_url': None,
                   'verify_ssl': None, 'follow_symlinks': True,
-                  'page_size': None}
+                  'page_size': None, 'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'rb', params)
         cmd_arc.create_instructions()
         rc = cmd_arc.run()
@@ -357,7 +364,7 @@ def test_run_rb_nonzero_rc(self):
                   'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3',
                   'region': 'us-east-1', 'endpoint_url': None,
                   'verify_ssl': None, 'follow_symlinks': True,
-                  'page_size': None}
+                  'page_size': None, 'is_stream': False}
         cmd_arc = CommandArchitecture(self.session, 'rb', params)
         cmd_arc.create_instructions()
         rc = cmd_arc.run()
@@ -468,6 +475,34 @@ def test_check_force(self):
         cmd_params.parameters['src'] = 's3://mybucket'
         cmd_params.check_force(None)
 
+    def test_validate_streaming_paths_upload(self):
+        parameters = {'src': '-', 'dest': 's3://bucket'}
+        cmd_params = CommandParameters(self.session, 'cp', parameters, '')
+        cmd_params._validate_streaming_paths()
+        self.assertTrue(cmd_params.parameters['is_stream'])
+        self.assertTrue(cmd_params.parameters['quiet'])
+        self.assertFalse(cmd_params.parameters['dir_op'])
+
+    def test_validate_streaming_paths_download(self):
+        parameters = {'src': 'localfile', 'dest': '-'}
+        cmd_params = CommandParameters(self.session, 'cp', parameters, '')
+        cmd_params._validate_streaming_paths()
+        self.assertTrue(cmd_params.parameters['is_stream'])
+        self.assertTrue(cmd_params.parameters['quiet'])
+        self.assertFalse(cmd_params.parameters['dir_op'])
+
+    def test_validate_no_streaming_paths(self):
+        parameters = {'src': 'localfile', 'dest': 's3://bucket'}
+        cmd_params = CommandParameters(self.session, 'cp', parameters, '')
+        cmd_params._validate_streaming_paths()
+        self.assertFalse(cmd_params.parameters['is_stream'])
+
+    def test_validate_streaming_paths_error(self):
+        parameters = {'src': '-', 'dest': 's3://bucket'}
+        cmd_params = CommandParameters(self.session, 'sync', parameters, '')
+        with self.assertRaises(ValueError):
+            cmd_params._validate_streaming_paths()
+
 
 class HelpDocTest(BaseAWSHelpOutputTest):
     def setUp(self):
diff --git a/tests/unit/customizations/s3/test_tasks.py b/tests/unit/customizations/s3/test_tasks.py
index 4451c85cb5697..eda16f7657789 100644
--- a/tests/unit/customizations/s3/test_tasks.py
+++ b/tests/unit/customizations/s3/test_tasks.py
@@ -22,6 +22,7 @@
 from awscli.customizations.s3.tasks import CompleteDownloadTask
 from awscli.customizations.s3.tasks import DownloadPartTask
 from awscli.customizations.s3.tasks import MultipartUploadContext
+from awscli.customizations.s3.tasks import MultipartDownloadContext
 from awscli.customizations.s3.tasks import UploadCancelledError
 from awscli.customizations.s3.tasks import print_operation
 from awscli.customizations.s3.tasks import RetriesExeededError
@@ -163,6 +164,58 @@ def test_basic_threaded_parts(self):
             self.calls[2][1:],
             ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}]))
 
+    def test_streaming_threaded_parts(self):
+        # This is similar to the basic threaded parts test, but here
+        # the completion thread has to wait to find out exactly how many
+        # parts are expected from the stream. This is signaled when the
+        # expected parts of the context changes from '...' to an integer.
+
+        self.context = MultipartUploadContext(expected_parts='...')
+        upload_part_thread = threading.Thread(target=self.upload_part,
+                                              args=(1,))
+        # Once this thread starts it will immediately block.
+        self.start_thread(upload_part_thread)
+
+        # Also, let's start the thread that will do the complete
+        # multipart upload. It will also block because it needs all
+        # the parts, so it's blocked on the upload_part_thread. It also
+        # needs the upload_id, so it's blocked on that as well.
+        complete_upload_thread = threading.Thread(target=self.complete_upload)
+        self.start_thread(complete_upload_thread)
+
+        # Then finally the CreateMultipartUpload completes and we
+        # announce the upload id.
+        self.create_upload('my_upload_id')
+        # The complete upload thread should still be waiting for the
+        # expected number of parts.
+        with self.call_lock:
+            was_completed = (len(self.calls) > 2)
+
+        # The upload_part thread can now proceed as well as the complete
+        # multipart upload thread.
+        self.context.announce_total_parts(1)
+        self.join_threads()
+
+        self.assertIsNone(self.caught_exception)
+
+        # Make sure that the completed task was never called since it was
+        # waiting to announce the parts.
+        self.assertFalse(was_completed)
+
+        # We can verify that the invariants still hold.
+        self.assertEqual(len(self.calls), 3)
+        # First there should be three calls, create, upload, complete.
+        self.assertEqual(self.calls[0][0], 'create_multipart_upload')
+        self.assertEqual(self.calls[1][0], 'upload_part')
+        self.assertEqual(self.calls[2][0], 'complete_upload')
+
+        # Verify the correct args were used.
+        self.assertEqual(self.calls[0][1], 'my_upload_id')
+        self.assertEqual(self.calls[1][1:], (1, 'my_upload_id'))
+        self.assertEqual(
+            self.calls[2][1:],
+            ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}]))
+
     def test_randomized_stress_test(self):
         # Now given that we've verified the functionality from
         # the two tests above, we randomize the threading to ensure
@@ -279,6 +332,7 @@ def setUp(self):
         self.filename.size = 10 * 1024 * 1024
         self.filename.src = 'bucket/key'
         self.filename.dest = 'local/file'
+        self.filename.is_stream = False
         self.filename.service = self.service
         self.filename.operation_name = 'download'
         self.context = mock.Mock()
@@ -325,9 +379,9 @@ def test_download_queues_io_properly(self):
         call_args_list = self.io_queue.put.call_args_list
         self.assertEqual(len(call_args_list), 2)
         self.assertEqual(call_args_list[0],
                          mock.call(('local/file', 0, b'foobar')))
+                         mock.call(('local/file', 0, b'foobar', False)))
         self.assertEqual(call_args_list[1],
-                         mock.call(('local/file', 6, b'morefoobar')))
+                         mock.call(('local/file', 6, b'morefoobar', False)))
 
     def test_incomplete_read_is_retried(self):
         self.service.get_operation.return_value.call.side_effect = \
@@ -342,6 +396,61 @@ def test_incomplete_read_is_retried(self):
                          self.service.get_operation.call_count)
 
 
+class TestMultipartDownloadContext(unittest.TestCase):
+    def setUp(self):
+        self.context = MultipartDownloadContext(num_parts=2)
+        self.calls = []
+        self.threads = []
+        self.call_lock = threading.Lock()
+        self.caught_exception = None
+
+    def tearDown(self):
+        self.join_threads()
+
+    def join_threads(self):
+        for thread in self.threads:
+            thread.join()
+
+    def download_stream_part(self, part_number):
+        try:
+            self.context.wait_for_turn(part_number)
+            with self.call_lock:
+                self.calls.append(('download_part', str(part_number)))
+            self.context.done_with_turn()
+        except Exception as e:
+            self.caught_exception = e
+            return
+
+    def start_thread(self, thread):
+        thread.start()
+        self.threads.append(thread)
+
+    def test_stream_context(self):
+        part_thread = threading.Thread(target=self.download_stream_part,
+                                       args=(1,))
+        # Once this thread starts it will immediately block because it is
+        # waiting for part zero to finish submitting its task.
+        self.start_thread(part_thread)
+
+        # Now create the thread that should submit its task first.
+        part_thread2 = threading.Thread(target=self.download_stream_part,
+                                        args=(0,))
+        self.start_thread(part_thread2)
+        self.join_threads()
+
+        self.assertIsNone(self.caught_exception)
+
+        # We can verify that the invariants still hold.
+        self.assertEqual(len(self.calls), 2)
+        # Both calls should be download_part calls.
+        self.assertEqual(self.calls[0][0], 'download_part')
+        self.assertEqual(self.calls[1][0], 'download_part')
+
+        # Verify the parts were submitted in the correct order.
+        self.assertEqual(self.calls[0][1], '0')
+        self.assertEqual(self.calls[1][1], '1')
+
+
 class TestTaskOrdering(unittest.TestCase):
     def setUp(self):
         self.q = StablePriorityQueue(maxsize=10, max_priority=20)
diff --git a/tests/unit/test_completer.py b/tests/unit/test_completer.py
index fc5365b40a69f..aab65f01a5d92 100644
--- a/tests/unit/test_completer.py
+++ b/tests/unit/test_completer.py
@@ -73,7 +73,8 @@
                              '--cache-control', '--content-type',
                              '--content-disposition', '--source-region',
                              '--content-encoding', '--content-language',
-                             '--expires', '--grants'] + GLOBALOPTS)),
+                             '--expires', '--grants', '--expected-size'] +
+                            GLOBALOPTS)),
     ('aws s3 cp --quiet -', -1, set(['--no-guess-mime-type', '--dryrun',
                                      '--recursive', '--content-type',
                                      '--follow-symlinks', '--no-follow-symlinks',
@@ -82,7 +83,7 @@
                                      '--expires', '--website-redirect', '--acl',
                                      '--storage-class', '--sse',
                                      '--exclude', '--include',
-                                     '--source-region',
+                                     '--source-region', '--expected-size',
                                      '--grants'] + GLOBALOPTS)),
     ('aws emr ', -1, set(['add-instance-groups', 'add-steps', 'add-tags',
                           'create-cluster', 'create-default-roles',
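
The new test_streaming_threaded_parts test above hinges on one behaviour: while a stream is still being read, the upload context's expected part count is the placeholder '...', and completion must not run until announce_total_parts() replaces it with a real integer. The actual MultipartUploadContext lives in awscli/customizations/s3/tasks.py and is not reproduced in this excerpt; the following is only a minimal sketch of that gating behaviour, and every name other than expected_parts and announce_total_parts is an illustrative assumption, not the awscli API.

import threading


class StreamUploadGate(object):
    # Illustrative sketch only, not the awscli implementation: completion
    # blocks until the total number of parts, unknown while reading from
    # stdin, has been announced and every announced part has finished.
    def __init__(self, expected_parts='...'):
        self._expected_parts = expected_parts
        self._parts_finished = 0
        self._condition = threading.Condition()

    def announce_total_parts(self, total_parts):
        # Called once the end of the input stream has been reached.
        with self._condition:
            self._expected_parts = total_parts
            self._condition.notify_all()

    def part_finished(self):
        # Called by each upload worker as its part completes.
        with self._condition:
            self._parts_finished += 1
            self._condition.notify_all()

    def wait_for_parts_to_finish(self):
        # Blocks while the part count is unknown ('...') or parts remain.
        with self._condition:
            while (self._expected_parts == '...' or
                   self._parts_finished < self._expected_parts):
                self._condition.wait()

Under that model, a complete-upload thread would stay blocked in wait_for_parts_to_finish() until announce_total_parts(1) runs, mirroring why the test asserts that was_completed is False before the announcement.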
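
TestMultipartDownloadContext exercises a turn-taking protocol: each part calls wait_for_turn(part_number) before queueing its data and done_with_turn() afterwards, so streamed parts reach stdout in part-number order even though they are downloaded by several threads. Again, the real MultipartDownloadContext is defined in awscli/customizations/s3/tasks.py and is not shown here; a minimal condition-variable sketch of the protocol (the class and attribute names below are assumptions) could look like this:

import threading


class OrderedTurnContext(object):
    # Illustrative sketch only: parts take strictly increasing turns so
    # that a streamed download is written out in part-number order.
    def __init__(self, num_parts):
        self._num_parts = num_parts
        self._current_part = 0
        self._condition = threading.Condition()

    def wait_for_turn(self, part_number):
        # Block until every lower-numbered part has finished its turn.
        with self._condition:
            while self._current_part != part_number:
                self._condition.wait()

    def done_with_turn(self):
        # Hand the turn to the next part and wake any waiting threads.
        with self._condition:
            self._current_part += 1
            self._condition.notify_all()

With such a context, the thread started with args=(1,) in test_stream_context blocks inside wait_for_turn(1) until the part-zero thread calls done_with_turn(), which is exactly the ordering the final assertions check.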