diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bbb34bafa910..474464ed743a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,10 @@ Next Release (TBD) (`issue 919 `__) * feature:``aws s3``: Add ``--only-show-errors`` option that displays errors and warnings but suppresses all other output. +* feature:``aws s3 cp``: Added ability to upload local + file streams from standard input to s3 and download s3 + objects as local file streams to standard output. + (`issue 903 `__) 1.4.4 diff --git a/awscli/customizations/s3/executor.py b/awscli/customizations/s3/executor.py index cb31d709cb95..d2f2c9b89152 100644 --- a/awscli/customizations/s3/executor.py +++ b/awscli/customizations/s3/executor.py @@ -15,8 +15,8 @@ import sys import threading -from awscli.customizations.s3.utils import uni_print, \ - IORequest, IOCloseRequest, StablePriorityQueue +from awscli.customizations.s3.utils import uni_print, bytes_print, \ + IORequest, IOCloseRequest, StablePriorityQueue from awscli.customizations.s3.tasks import OrderableTask @@ -154,15 +154,19 @@ def run(self): self._cleanup() return elif isinstance(task, IORequest): - filename, offset, data = task - fileobj = self.fd_descriptor_cache.get(filename) - if fileobj is None: - fileobj = open(filename, 'rb+') - self.fd_descriptor_cache[filename] = fileobj - fileobj.seek(offset) + filename, offset, data, is_stream = task + if is_stream: + fileobj = sys.stdout + bytes_print(data) + else: + fileobj = self.fd_descriptor_cache.get(filename) + if fileobj is None: + fileobj = open(filename, 'rb+') + self.fd_descriptor_cache[filename] = fileobj + fileobj.seek(offset) + fileobj.write(data) LOGGER.debug("Writing data to: %s, offset: %s", filename, offset) - fileobj.write(data) fileobj.flush() elif isinstance(task, IOCloseRequest): LOGGER.debug("IOCloseRequest received for %s, closing file.", @@ -239,7 +243,7 @@ def __init__(self, result_queue, quiet, only_show_errors): self._lock = threading.Lock() self._needs_newline = False - self._total_parts = 0 + self._total_parts = '...' self._total_files = '...' # This is a public attribute that clients can inspect to determine diff --git a/awscli/customizations/s3/filegenerator.py b/awscli/customizations/s3/filegenerator.py index b53be0c45939..0ee72402d9c7 100644 --- a/awscli/customizations/s3/filegenerator.py +++ b/awscli/customizations/s3/filegenerator.py @@ -20,7 +20,8 @@ from dateutil.tz import tzlocal from awscli.customizations.s3.utils import find_bucket_key, get_file_stat -from awscli.customizations.s3.utils import BucketLister, create_warning +from awscli.customizations.s3.utils import BucketLister, create_warning, \ + find_dest_path_comp_key from awscli.errorhandler import ClientError @@ -131,26 +132,13 @@ def call(self, files): ``dir_op`` and ``use_src_name`` flags affect which files are used and ensure the proper destination paths and compare keys are formed. 
""" - src = files['src'] - dest = files['dest'] - src_type = src['type'] - dest_type = dest['type'] function_table = {'s3': self.list_objects, 'local': self.list_files} - sep_table = {'s3': '/', 'local': os.sep} - source = src['path'] + source = files['src']['path'] + src_type = files['src']['type'] + dest_type = files['dest']['type'] file_list = function_table[src_type](source, files['dir_op']) for src_path, size, last_update in file_list: - if files['dir_op']: - rel_path = src_path[len(src['path']):] - else: - rel_path = src_path.split(sep_table[src_type])[-1] - compare_key = rel_path.replace(sep_table[src_type], '/') - if files['use_src_name']: - dest_path = dest['path'] - dest_path += rel_path.replace(sep_table[src_type], - sep_table[dest_type]) - else: - dest_path = dest['path'] + dest_path, compare_key = find_dest_path_comp_key(files, src_path) yield FileStat(src=src_path, dest=dest_path, compare_key=compare_key, size=size, last_update=last_update, src_type=src_type, diff --git a/awscli/customizations/s3/fileinfo.py b/awscli/customizations/s3/fileinfo.py index fe482e64d13b..b30c67bcdc57 100644 --- a/awscli/customizations/s3/fileinfo.py +++ b/awscli/customizations/s3/fileinfo.py @@ -11,7 +11,7 @@ from botocore.compat import quote from awscli.customizations.s3.utils import find_bucket_key, \ check_etag, check_error, operate, uni_print, \ - guess_content_type, MD5Error + guess_content_type, MD5Error, bytes_print class CreateDirectoryError(Exception): @@ -26,7 +26,7 @@ def read_file(filename): return in_file.read() -def save_file(filename, response_data, last_update): +def save_file(filename, response_data, last_update, is_stream=False): """ This writes to the file upon downloading. It reads the data in the response. Makes a new directory if needed and then writes the @@ -35,31 +35,57 @@ def save_file(filename, response_data, last_update): """ body = response_data['Body'] etag = response_data['ETag'][1:-1] - d = os.path.dirname(filename) - try: - if not os.path.exists(d): - os.makedirs(d) - except OSError as e: - if not e.errno == errno.EEXIST: - raise CreateDirectoryError( - "Could not create directory %s: %s" % (d, e)) + if not is_stream: + d = os.path.dirname(filename) + try: + if not os.path.exists(d): + os.makedirs(d) + except OSError as e: + if not e.errno == errno.EEXIST: + raise CreateDirectoryError( + "Could not create directory %s: %s" % (d, e)) md5 = hashlib.md5() file_chunks = iter(partial(body.read, 1024 * 1024), b'') - with open(filename, 'wb') as out_file: - if not _is_multipart_etag(etag): - for chunk in file_chunks: - md5.update(chunk) - out_file.write(chunk) - else: - for chunk in file_chunks: - out_file.write(chunk) + if is_stream: + # Need to save the data to be able to check the etag for a stream + # becuase once the data is written to the stream there is no + # undoing it. 
+ payload = write_to_file(None, etag, md5, file_chunks, True) + else: + with open(filename, 'wb') as out_file: + write_to_file(out_file, etag, md5, file_chunks) + if not _is_multipart_etag(etag): if etag != md5.hexdigest(): - os.remove(filename) + if not is_stream: + os.remove(filename) raise MD5Error(filename) - last_update_tuple = last_update.timetuple() - mod_timestamp = time.mktime(last_update_tuple) - os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + + if not is_stream: + last_update_tuple = last_update.timetuple() + mod_timestamp = time.mktime(last_update_tuple) + os.utime(filename, (int(mod_timestamp), int(mod_timestamp))) + else: + # Now write the output to stdout since the md5 is correct. + bytes_print(payload) + sys.stdout.flush() + + +def write_to_file(out_file, etag, md5, file_chunks, is_stream=False): + """ + Updates the etag for each file chunk. It will write to the file if it a + file but if it is a stream it will return a byte string to be later + written to a stream. + """ + body = b'' + for chunk in file_chunks: + if not _is_multipart_etag(etag): + md5.update(chunk) + if is_stream: + body += chunk + else: + out_file.write(chunk) + return body def _is_multipart_etag(etag): @@ -140,7 +166,7 @@ class FileInfo(TaskInfo): def __init__(self, src, dest=None, compare_key=None, size=None, last_update=None, src_type=None, dest_type=None, operation_name=None, service=None, endpoint=None, - parameters=None, source_endpoint=None): + parameters=None, source_endpoint=None, is_stream=False): super(FileInfo, self).__init__(src, src_type=src_type, operation_name=operation_name, service=service, @@ -157,6 +183,18 @@ def __init__(self, src, dest=None, compare_key=None, size=None, self.parameters = {'acl': None, 'sse': None} self.source_endpoint = source_endpoint + self.is_stream = is_stream + + def set_size_from_s3(self): + """ + This runs a ``HeadObject`` on the s3 object and sets the size. + """ + bucket, key = find_bucket_key(self.src) + params = {'endpoint': self.endpoint, + 'bucket': bucket, + 'key': key} + response_data, http = operate(self.service, 'HeadObject', params) + self.size = int(response_data['ContentLength']) def _permission_to_param(self, permission): if permission == 'read': @@ -204,24 +242,30 @@ def _handle_object_params(self, params): if self.parameters['expires']: params['expires'] = self.parameters['expires'][0] - def upload(self): + def upload(self, payload=None): """ Redirects the file to the multipart upload function if the file is large. If it is small enough, it puts the file as an object in s3. """ - with open(self.src, 'rb') as body: - bucket, key = find_bucket_key(self.dest) - params = { - 'endpoint': self.endpoint, - 'bucket': bucket, - 'key': key, - 'body': body, - } - self._handle_object_params(params) - response_data, http = operate(self.service, 'PutObject', params) - etag = response_data['ETag'][1:-1] - body.seek(0) - check_etag(etag, body) + if payload: + self._handle_upload(payload) + else: + with open(self.src, 'rb') as body: + self._handle_upload(body) + + def _handle_upload(self, body): + bucket, key = find_bucket_key(self.dest) + params = { + 'endpoint': self.endpoint, + 'bucket': bucket, + 'key': key, + 'body': body, + } + self._handle_object_params(params) + response_data, http = operate(self.service, 'PutObject', params) + etag = response_data['ETag'][1:-1] + body.seek(0) + check_etag(etag, body) def _inject_content_type(self, params, filename): # Add a content type param if we can guess the type. 
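Note on the ``save_file``/``write_to_file`` changes above: a streamed download is buffered in memory so the MD5 can be checked against the ETag before anything reaches standard output, since bytes written to a stream cannot be taken back the way a bad file can be removed. A minimal standalone sketch of that buffer-verify-then-print flow, assuming the same 1 MB chunk size; the function name and the ``ValueError`` are illustrative stand-ins (the CLI raises ``MD5Error`` and skips the check for multipart ETags, which contain a ``-``)::

    import hashlib
    import io
    import sys
    from functools import partial

    def stream_download_sketch(body, etag):
        # Buffer the whole payload, verify it, and only then write to
        # stdout -- mirroring save_file(..., is_stream=True) above.
        md5 = hashlib.md5()
        payload = b''
        for chunk in iter(partial(body.read, 1024 * 1024), b''):
            if '-' not in etag:  # multipart ETags are not plain MD5 digests
                md5.update(chunk)
            payload += chunk
        if '-' not in etag and etag != md5.hexdigest():
            raise ValueError("stream download failed MD5 check")
        out = getattr(sys.stdout, 'buffer', sys.stdout)  # bytes-capable handle
        out.write(payload)
        out.flush()

    # Example: a fake GetObject body whose ETag is its MD5 digest.
    data = b'This is a test'
    stream_download_sketch(io.BytesIO(data), hashlib.md5(data).hexdigest())
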
@@ -237,7 +281,8 @@ def download(self): bucket, key = find_bucket_key(self.src) params = {'endpoint': self.endpoint, 'bucket': bucket, 'key': key} response_data, http = operate(self.service, 'GetObject', params) - save_file(self.dest, response_data, self.last_update) + save_file(self.dest, response_data, self.last_update, + self.is_stream) def copy(self): """ diff --git a/awscli/customizations/s3/fileinfobuilder.py b/awscli/customizations/s3/fileinfobuilder.py index 8bc2042615ef..9f1c429f0fc0 100644 --- a/awscli/customizations/s3/fileinfobuilder.py +++ b/awscli/customizations/s3/fileinfobuilder.py @@ -19,13 +19,14 @@ class FileInfoBuilder(object): a ``FileInfo`` object so that the operation can be performed. """ def __init__(self, service, endpoint, source_endpoint=None, - parameters = None): + parameters = None, is_stream=False): self._service = service self._endpoint = endpoint self._source_endpoint = endpoint if source_endpoint: self._source_endpoint = source_endpoint - self._parameters = parameters + self._parameters = parameters + self._is_stream = is_stream def call(self, files): for file_base in files: @@ -46,4 +47,5 @@ def _inject_info(self, file_base): file_info_attr['endpoint'] = self._endpoint file_info_attr['source_endpoint'] = self._source_endpoint file_info_attr['parameters'] = self._parameters + file_info_attr['is_stream'] = self._is_stream return FileInfo(**file_info_attr) diff --git a/awscli/customizations/s3/s3handler.py b/awscli/customizations/s3/s3handler.py index cbdf1b4b4aeb..dc226eafcd6d 100644 --- a/awscli/customizations/s3/s3handler.py +++ b/awscli/customizations/s3/s3handler.py @@ -14,7 +14,9 @@ import logging import math import os +import six from six.moves import queue +import sys from awscli.customizations.s3.constants import MULTI_THRESHOLD, CHUNKSIZE, \ NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE @@ -36,6 +38,8 @@ class S3Handler(object): class pull tasks from to complete. 
""" MAX_IO_QUEUE_SIZE = 20 + MAX_EXECUTOR_QUEUE_SIZE = MAX_QUEUE_SIZE + EXECUTOR_NUM_THREADS = NUM_THREADS def __init__(self, session, params, result_queue=None, multi_threshold=MULTI_THRESHOLD, chunksize=CHUNKSIZE): @@ -53,7 +57,9 @@ def __init__(self, session, params, result_queue=None, 'content_type': None, 'cache_control': None, 'content_disposition': None, 'content_encoding': None, 'content_language': None, 'expires': None, - 'grants': None, 'only_show_errors': False} + 'grants': None, 'only_show_errors': False, + 'is_stream': False, 'paths_type': None, + 'expected_size': None} self.params['region'] = params['region'] for key in self.params.keys(): if key in params: @@ -61,7 +67,8 @@ def __init__(self, session, params, result_queue=None, self.multi_threshold = multi_threshold self.chunksize = chunksize self.executor = Executor( - num_threads=NUM_THREADS, result_queue=self.result_queue, + num_threads=self.EXECUTOR_NUM_THREADS, + result_queue=self.result_queue, quiet=self.params['quiet'], only_show_errors=self.params['only_show_errors'], max_queue_size=MAX_QUEUE_SIZE, write_queue=self.write_queue @@ -235,12 +242,11 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): create_file_task = tasks.CreateLocalFileTask(context=context, filename=filename) self.executor.submit(create_file_task) - for i in range(num_downloads): - task = tasks.DownloadPartTask( - part_number=i, chunk_size=chunksize, - result_queue=self.result_queue, service=filename.service, - filename=filename, context=context, io_queue=self.write_queue) - self.executor.submit(task) + self._do_enqueue_range_download_tasks( + filename=filename, chunksize=chunksize, + num_downloads=num_downloads, context=context, + remove_remote_file=remove_remote_file + ) complete_file_task = tasks.CompleteDownloadTask( context=context, filename=filename, result_queue=self.result_queue, params=self.params, io_queue=self.write_queue) @@ -252,6 +258,16 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): self.executor.submit(remove_task) return num_downloads + def _do_enqueue_range_download_tasks(self, filename, chunksize, + num_downloads, context, + remove_remote_file=False): + for i in range(num_downloads): + task = tasks.DownloadPartTask( + part_number=i, chunk_size=chunksize, + result_queue=self.result_queue, service=filename.service, + filename=filename, context=context, io_queue=self.write_queue) + self.executor.submit(task) + def _enqueue_multipart_upload_tasks(self, filename, remove_local_file=False): # First we need to create a CreateMultipartUpload task, @@ -296,14 +312,27 @@ def _enqueue_upload_start_task(self, chunksize, num_uploads, filename): self.executor.submit(create_multipart_upload_task) return upload_context - def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, filename, - task_class): + def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, + filename, task_class): for i in range(1, (num_uploads + 1)): - task = task_class( - part_number=i, chunk_size=chunksize, - result_queue=self.result_queue, upload_context=upload_context, - filename=filename) - self.executor.submit(task) + self._enqueue_upload_single_part_task( + part_number=i, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class + ) + + def _enqueue_upload_single_part_task(self, part_number, chunk_size, + upload_context, filename, task_class, + payload=None): + kwargs = {'part_number': part_number, 'chunk_size': chunk_size, + 
'result_queue': self.result_queue, + 'upload_context': upload_context, 'filename': filename} + if payload: + kwargs['payload'] = payload + task = task_class(**kwargs) + self.executor.submit(task) def _enqueue_upload_end_task(self, filename, upload_context): complete_multipart_upload_task = tasks.CompleteMultipartUploadTask( @@ -312,3 +341,157 @@ def _enqueue_upload_end_task(self, filename, upload_context): self.executor.submit(complete_multipart_upload_task) self._multipart_uploads.append((upload_context, filename)) + +class S3StreamHandler(S3Handler): + """ + This class is an alternative ``S3Handler`` to be used when the operation + involves a stream since the logic is different when uploading and + downloading streams. + """ + + # This ensures that the number of multipart chunks waiting in the + # executor queue and in the threads is limited. + MAX_EXECUTOR_QUEUE_SIZE = 2 + EXECUTOR_NUM_THREADS = 6 + + def _enqueue_tasks(self, files): + total_files = 0 + total_parts = 0 + for filename in files: + num_uploads = 1 + # If uploading stream, it is required to read from the stream + # to determine if the stream needs to be multipart uploaded. + payload = None + if filename.operation_name == 'upload': + payload, is_multipart_task = \ + self._pull_from_stream(self.multi_threshold) + else: + # Set the file size for the ``FileInfo`` object since + # streams do not use a ``FileGenerator`` that usually + # determines the size. + filename.set_size_from_s3() + is_multipart_task = self._is_multipart_task(filename) + if is_multipart_task and not self.params['dryrun']: + # If we're in dryrun mode, then we don't need the + # real multipart tasks. We can just use a BasicTask + # in the else clause below, which will print out the + # fact that it's transferring a file rather than + # the specific part tasks required to perform the + # transfer. + num_uploads = self._enqueue_multipart_tasks(filename, payload) + else: + task = tasks.BasicTask( + session=self.session, filename=filename, + parameters=self.params, + result_queue=self.result_queue, + payload=payload) + self.executor.submit(task) + total_files += 1 + total_parts += num_uploads + return total_files, total_parts + + def _pull_from_stream(self, amount_requested): + """ + This function pulls data from stdin until it hits the amount + requested or there is no more left to pull in from stdin. The + function wraps the data into a ``BytesIO`` object that is returned + along with a boolean telling whether the amount requested is + the amount returned. + """ + stream_filein = sys.stdin + if six.PY3: + stream_filein = sys.stdin.buffer + payload = stream_filein.read(amount_requested) + payload_file = six.BytesIO(payload) + return payload_file, len(payload) == amount_requested + + def _enqueue_multipart_tasks(self, filename, payload=None): + num_uploads = 1 + if filename.operation_name == 'upload': + num_uploads = self._enqueue_multipart_upload_tasks(filename, + payload=payload) + elif filename.operation_name == 'download': + num_uploads = self._enqueue_range_download_tasks(filename) + return num_uploads + + def _enqueue_range_download_tasks(self, filename, remove_remote_file=False): + + # Create the context for the multipart download. + chunksize = find_chunksize(filename.size, self.chunksize) + num_downloads = int(filename.size / chunksize) + context = tasks.MultipartDownloadContext(num_downloads) + + # No file is needed for downloading a stream. So just announce + # that it has been made since it is required for the context to + # begin downloading. 
+ context.announce_file_created() + + # Submit download part tasks to the executor. + self._do_enqueue_range_download_tasks( + filename=filename, chunksize=chunksize, + num_downloads=num_downloads, context=context, + remove_remote_file=remove_remote_file + ) + return num_downloads + + def _enqueue_multipart_upload_tasks(self, filename, payload=None): + # First we need to create a CreateMultipartUpload task, + # then create UploadTask objects for each of the parts. + # And finally enqueue a CompleteMultipartUploadTask. + + chunksize = self.chunksize + # Determine an appropriate chunksize if given an expected size. + if self.params['expected_size']: + chunksize = find_chunksize(int(self.params['expected_size']), + self.chunksize) + num_uploads = '...' + + # Submit a task to begin the multipart upload. + upload_context = self._enqueue_upload_start_task( + chunksize, num_uploads, filename) + + # Now submit a task to upload the initial chunk of data pulled + # from the stream that was used to determine if a multipart upload + # was needed. + self._enqueue_upload_single_part_task( + part_number=1, chunk_size=chunksize, + upload_context=upload_context, filename=filename, + task_class=tasks.UploadPartTask, payload=payload + ) + + # Submit tasks to upload the rest of the chunks of the data coming in + # from standard input. + num_uploads = self._enqueue_upload_tasks( + num_uploads, chunksize, upload_context, + filename, tasks.UploadPartTask + ) + + # Submit a task to notify the multipart upload is complete. + self._enqueue_upload_end_task(filename, upload_context) + + return num_uploads + + def _enqueue_upload_tasks(self, num_uploads, chunksize, upload_context, + filename, task_class): + # The previous upload occured right after the multipart + # upload started for a stream. + num_uploads = 1 + while True: + # Pull more data from standard input. + payload, is_remaining = self._pull_from_stream(chunksize) + # Submit an upload part task for the recently pulled data. + self._enqueue_upload_single_part_task( + part_number=num_uploads+1, + chunk_size=chunksize, + upload_context=upload_context, + filename=filename, + task_class=task_class, + payload=payload + ) + num_uploads += 1 + if not is_remaining: + break + # Once there is no more data left, announce to the context how + # many parts are being uploaded so it knows when it can quit. + upload_context.announce_total_parts(num_uploads) + return num_uploads diff --git a/awscli/customizations/s3/subcommands.py b/awscli/customizations/s3/subcommands.py index 096e2a102e9a..0ce4613fb7fb 100644 --- a/awscli/customizations/s3/subcommands.py +++ b/awscli/customizations/s3/subcommands.py @@ -23,11 +23,11 @@ from awscli.customizations.s3.fileinfobuilder import FileInfoBuilder from awscli.customizations.s3.fileformat import FileFormat from awscli.customizations.s3.filegenerator import FileGenerator -from awscli.customizations.s3.fileinfo import TaskInfo +from awscli.customizations.s3.fileinfo import TaskInfo, FileInfo from awscli.customizations.s3.filters import create_filter -from awscli.customizations.s3.s3handler import S3Handler +from awscli.customizations.s3.s3handler import S3Handler, S3StreamHandler from awscli.customizations.s3.utils import find_bucket_key, uni_print, \ - AppendFilter + AppendFilter, find_dest_path_comp_key RECURSIVE = {'name': 'recursive', 'action': 'store_true', 'dest': 'dir_op', @@ -211,6 +211,15 @@ 'Only errors and warnings are displayed. 
All other ' 'output is suppressed.')} +EXPECTED_SIZE = {'name': 'expected-size', + 'help_text': ( + 'This argument specifies the expected size of a stream ' + 'in terms of bytes. Note that this argument is needed ' + 'only when a stream is being uploaded to s3 and the size ' + 'is larger than 5GB. Failure to include this argument ' + 'under these conditions may result in a failed upload. ' + 'due to too many parts in upload.')} + TRANSFER_ARGS = [DRYRUN, QUIET, RECURSIVE, INCLUDE, EXCLUDE, ACL, FOLLOW_SYMLINKS, NO_FOLLOW_SYMLINKS, NO_GUESS_MIME_TYPE, SSE, STORAGE_CLASS, GRANTS, WEBSITE_REDIRECT, CONTENT_TYPE, @@ -415,7 +424,7 @@ class CpCommand(S3TransferCommand): USAGE = " or " \ "or " ARG_TABLE = [{'name': 'paths', 'nargs': 2, 'positional_arg': True, - 'synopsis': USAGE}] + TRANSFER_ARGS + 'synopsis': USAGE}] + TRANSFER_ARGS + [EXPECTED_SIZE] EXAMPLES = BasicCommand.FROM_FILE('s3/cp.rst') @@ -512,16 +521,21 @@ def create_instructions(self): instruction list because it sends the request to S3 and does not yield anything. """ - if self.cmd not in ['mb', 'rb']: + if self.needs_filegenerator(): self.instructions.append('file_generator') - if self.parameters.get('filters'): - self.instructions.append('filters') - if self.cmd == 'sync': - self.instructions.append('comparator') - if self.cmd not in ['mb', 'rb']: + if self.parameters.get('filters'): + self.instructions.append('filters') + if self.cmd == 'sync': + self.instructions.append('comparator') self.instructions.append('file_info_builder') self.instructions.append('s3_handler') + def needs_filegenerator(self): + if self.cmd in ['mb', 'rb'] or self.parameters['is_stream']: + return False + else: + return True + def run(self): """ This function wires together all of the generators and completes @@ -578,10 +592,22 @@ def run(self): operation_name=operation_name, service=self._service, endpoint=self._endpoint)] + stream_dest_path, stream_compare_key = find_dest_path_comp_key(files) + stream_file_info = [FileInfo(src=files['src']['path'], + dest=stream_dest_path, + compare_key=stream_compare_key, + src_type=files['src']['type'], + dest_type=files['dest']['type'], + operation_name=operation_name, + service=self._service, + endpoint=self._endpoint, + is_stream=True)] file_info_builder = FileInfoBuilder(self._service, self._endpoint, self._source_endpoint, self.parameters) s3handler = S3Handler(self.session, self.parameters, result_queue=result_queue) + s3_stream_handler = S3StreamHandler(self.session, self.parameters, + result_queue=result_queue) command_dict = {} if self.cmd == 'sync': @@ -593,6 +619,9 @@ def run(self): 'comparator': [Comparator(self.parameters)], 'file_info_builder': [file_info_builder], 's3_handler': [s3handler]} + elif self.cmd == 'cp' and self.parameters['is_stream']: + command_dict = {'setup': [stream_file_info], + 's3_handler': [s3_stream_handler]} elif self.cmd == 'cp': command_dict = {'setup': [files], 'file_generator': [file_generator], @@ -685,8 +714,19 @@ def add_paths(self, paths): self.parameters['dest'] = paths[1] elif len(paths) == 1: self.parameters['dest'] = paths[0] + self._validate_streaming_paths() self._validate_path_args() + def _validate_streaming_paths(self): + self.parameters['is_stream'] = False + if self.parameters['src'] == '-' or self.parameters['dest'] == '-': + self.parameters['is_stream'] = True + self.parameters['dir_op'] = False + self.parameters['only_show_errors'] = True + if self.parameters['is_stream'] and self.cmd != 'cp': + raise ValueError("Streaming currently is only compatible with " + 
"single file cp commands") + def _validate_path_args(self): # If we're using a mv command, you can't copy the object onto itself. params = self.parameters diff --git a/awscli/customizations/s3/tasks.py b/awscli/customizations/s3/tasks.py index 37089c42a2b9..be326f35e56d 100644 --- a/awscli/customizations/s3/tasks.py +++ b/awscli/customizations/s3/tasks.py @@ -63,7 +63,8 @@ class BasicTask(OrderableTask): attributes like ``session`` object in order for the filename to perform its designated operation. """ - def __init__(self, session, filename, parameters, result_queue): + def __init__(self, session, filename, parameters, + result_queue, payload=None): self.session = session self.service = self.session.get_service('s3') @@ -72,6 +73,7 @@ def __init__(self, session, filename, parameters, result_queue): self.parameters = parameters self.result_queue = result_queue + self.payload = payload def __call__(self): self._execute_task(attempts=3) @@ -84,9 +86,12 @@ def _execute_task(self, attempts, last_error=''): error_message=last_error) return filename = self.filename + kwargs = {} + if self.payload: + kwargs['payload'] = self.payload try: if not self.parameters['dryrun']: - getattr(filename, filename.operation_name)() + getattr(filename, filename.operation_name)(**kwargs) except requests.ConnectionError as e: connect_error = str(e) LOGGER.debug("%s %s failure: %s", @@ -195,13 +200,14 @@ class UploadPartTask(OrderableTask): complete the multipart upload initiated by the ``FileInfo`` object. """ - def __init__(self, part_number, chunk_size, - result_queue, upload_context, filename): + def __init__(self, part_number, chunk_size, result_queue, upload_context, + filename, payload=None): self._result_queue = result_queue self._upload_context = upload_context self._part_number = part_number self._chunk_size = chunk_size self._filename = filename + self._payload = payload def _read_part(self): actual_filename = self._filename.src @@ -216,9 +222,13 @@ def __call__(self): LOGGER.debug("Waiting for upload id.") upload_id = self._upload_context.wait_for_upload_id() bucket, key = find_bucket_key(self._filename.dest) - total = int(math.ceil( - self._filename.size/float(self._chunk_size))) - body = self._read_part() + if self._filename.is_stream: + body = self._payload + total = self._upload_context.expected_parts + else: + total = int(math.ceil( + self._filename.size/float(self._chunk_size))) + body = self._read_part() params = {'endpoint': self._filename.endpoint, 'bucket': bucket, 'key': key, 'part_number': self._part_number, @@ -393,16 +403,23 @@ def _queue_writes(self, body): body.set_socket_timeout(self.READ_TIMEOUT) amount_read = 0 current = body.read(iterate_chunk_size) + if self._filename.is_stream: + self._context.wait_for_turn(self._part_number) while current: offset = self._part_number * self._chunk_size + amount_read LOGGER.debug("Submitting IORequest to write queue.") - self._io_queue.put(IORequest(self._filename.dest, offset, current)) + self._io_queue.put( + IORequest(self._filename.dest, offset, current, + self._filename.is_stream) + ) LOGGER.debug("Request successfully submitted.") amount_read += len(current) current = body.read(iterate_chunk_size) # Change log message. 
LOGGER.debug("Done queueing writes for part number %s to file: %s", self._part_number, self._filename.dest) + if self._filename.is_stream: + self._context.done_with_turn() class CreateMultipartUploadTask(BasicTask): @@ -530,7 +547,7 @@ class MultipartUploadContext(object): _CANCELLED = '_CANCELLED' _COMPLETED = '_COMPLETED' - def __init__(self, expected_parts): + def __init__(self, expected_parts='...'): self._upload_id = None self._expected_parts = expected_parts self._parts = [] @@ -540,6 +557,10 @@ def __init__(self, expected_parts): self._upload_complete_condition = threading.Condition(self._lock) self._state = self._UNSTARTED + @property + def expected_parts(self): + return self._expected_parts + def announce_upload_id(self, upload_id): with self._upload_id_condition: self._upload_id = upload_id @@ -551,9 +572,15 @@ def announce_finished_part(self, etag, part_number): self._parts.append({'ETag': etag, 'PartNumber': part_number}) self._parts_condition.notifyAll() + def announce_total_parts(self, total_parts): + with self._parts_condition: + self._expected_parts = total_parts + self._parts_condition.notifyAll() + def wait_for_parts_to_finish(self): with self._parts_condition: - while len(self._parts) < self._expected_parts: + while self._expected_parts == '...' or \ + len(self._parts) < self._expected_parts: if self._state == self._CANCELLED: raise UploadCancelledError("Upload has been cancelled.") self._parts_condition.wait(timeout=1) @@ -653,9 +680,11 @@ def __init__(self, num_parts, lock=None): lock = threading.Lock() self._lock = lock self._created_condition = threading.Condition(self._lock) + self._submit_write_condition = threading.Condition(self._lock) self._completed_condition = threading.Condition(self._lock) self._state = self._STATES['UNSTARTED'] self._finished_parts = set() + self._current_stream_part_number = 0 def announce_completed_part(self, part_number): with self._completed_condition: @@ -685,6 +714,19 @@ def wait_for_completion(self): "Download has been cancelled.") self._completed_condition.wait(timeout=1) + def wait_for_turn(self, part_number): + with self._submit_write_condition: + while self._current_stream_part_number != part_number: + if self._state == self._STATES['CANCELLED']: + raise DownloadCancelledError( + "Download has been cancelled.") + self._submit_write_condition.wait(timeout=0.2) + + def done_with_turn(self): + with self._submit_write_condition: + self._current_stream_part_number += 1 + self._submit_write_condition.notifyAll() + def cancel(self): with self._lock: self._state = self._STATES['CANCELLED'] diff --git a/awscli/customizations/s3/utils.py b/awscli/customizations/s3/utils.py index 97b17933d9a8..9a0569ae66ce 100644 --- a/awscli/customizations/s3/utils.py +++ b/awscli/customizations/s3/utils.py @@ -145,6 +145,34 @@ def get_file_stat(path): return stats.st_size, update_time +def find_dest_path_comp_key(files, src_path=None): + """ + This is a helper function that determines the destination path and compare + key given parameters received from the ``FileFormat`` class. 
+ """ + src = files['src'] + dest = files['dest'] + src_type = src['type'] + dest_type = dest['type'] + if src_path is None: + src_path = src['path'] + + sep_table = {'s3': '/', 'local': os.sep} + + if files['dir_op']: + rel_path = src_path[len(src['path']):] + else: + rel_path = src_path.split(sep_table[src_type])[-1] + compare_key = rel_path.replace(sep_table[src_type], '/') + if files['use_src_name']: + dest_path = dest['path'] + dest_path += rel_path.replace(sep_table[src_type], + sep_table[dest_type]) + else: + dest_path = dest['path'] + return dest_path, compare_key + + def check_etag(etag, fileobj): """ This fucntion checks the etag and the md5 checksum to ensure no @@ -246,6 +274,21 @@ def uni_print(statement, out_file=None): out_file.flush() +def bytes_print(statement): + """ + This function is used to properly write bytes to standard out. + """ + if PY3: + if getattr(sys.stdout, 'buffer', None): + sys.stdout.buffer.write(statement) + else: + # If it is not possible to write to the standard out buffer. + # The next best option is to decode and write to standard out. + sys.stdout.write(statement.decode('utf-8')) + else: + sys.stdout.write(statement) + + def guess_content_type(filename): """Given a filename, guess it's content type. @@ -404,7 +447,8 @@ def __new__(cls, message, error=False, total_parts=None, warning=None): warning) -IORequest = namedtuple('IORequest', ['filename', 'offset', 'data']) +IORequest = namedtuple('IORequest', + ['filename', 'offset', 'data', 'is_stream']) # Used to signal that IO for the filename is finished, and that # any associated resources may be cleaned up. IOCloseRequest = namedtuple('IOCloseRequest', ['filename']) diff --git a/awscli/examples/s3/cp.rst b/awscli/examples/s3/cp.rst index 6bdf25dc0dc1..1fe488bc7751 100644 --- a/awscli/examples/s3/cp.rst +++ b/awscli/examples/s3/cp.rst @@ -101,3 +101,15 @@ Output:: upload: file.txt to s3://mybucket/file.txt +**Uploading a local file stream to S3** + +The following ``cp`` command uploads a local file stream from standard input to a specified bucket and key:: + + aws s3 cp - s3://mybucket/stream.txt + + +**Downloading a S3 object as a local file stream** + +The following ``cp`` command downloads a S3 object locally as a stream to standard output:: + + aws s3 cp s3://mybucket/stream.txt - diff --git a/awscli/testutils.py b/awscli/testutils.py index b6f4bd2abcc0..2f4b55e018c1 100644 --- a/awscli/testutils.py +++ b/awscli/testutils.py @@ -395,7 +395,7 @@ def _escape_quotes(command): def aws(command, collect_memory=False, env_vars=None, - wait_for_finish=True): + wait_for_finish=True, input_data=None): """Run an aws command. 
This help function abstracts the differences of running the "aws" @@ -421,7 +421,7 @@ def aws(command, collect_memory=False, env_vars=None, else: aws_command = 'python %s' % get_aws_cmd() full_command = '%s %s' % (aws_command, command) - stdout_encoding = _get_stdout_encoding() + stdout_encoding = get_stdout_encoding() if isinstance(full_command, six.text_type) and not six.PY3: full_command = full_command.encode(stdout_encoding) INTEG_LOG.debug("Running command: %s", full_command) @@ -429,13 +429,16 @@ def aws(command, collect_memory=False, env_vars=None, env['AWS_DEFAULT_REGION'] = "us-east-1" if env_vars is not None: env = env_vars - process = Popen(full_command, stdout=PIPE, stderr=PIPE, shell=True, - env=env) + process = Popen(full_command, stdout=PIPE, stderr=PIPE, stdin=PIPE, + shell=True, env=env) if not wait_for_finish: return process memory = None if not collect_memory: - stdout, stderr = process.communicate() + kwargs = {} + if input_data: + kwargs = {'input': input_data} + stdout, stderr = process.communicate(**kwargs) else: stdout, stderr, memory = _wait_and_collect_mem(process) return Result(process.returncode, @@ -444,7 +447,7 @@ def aws(command, collect_memory=False, env_vars=None, memory) -def _get_stdout_encoding(): +def get_stdout_encoding(): encoding = getattr(sys.__stdout__, 'encoding', None) if encoding is None: encoding = 'utf-8' diff --git a/tests/integration/customizations/s3/test_plugin.py b/tests/integration/customizations/s3/test_plugin.py index a58675a863f2..46f57ea11f1c 100644 --- a/tests/integration/customizations/s3/test_plugin.py +++ b/tests/integration/customizations/s3/test_plugin.py @@ -28,7 +28,7 @@ import botocore.session import six -from awscli.testutils import unittest, FileCreator +from awscli.testutils import unittest, FileCreator, get_stdout_encoding from awscli.testutils import aws as _aws from tests.unit.customizations.s3 import create_bucket as _create_bucket from awscli.customizations.s3 import constants @@ -44,12 +44,13 @@ def cd(directory): os.chdir(original) -def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True): +def aws(command, collect_memory=False, env_vars=None, wait_for_finish=True, + input_data=None): if not env_vars: env_vars = os.environ.copy() env_vars['AWS_DEFAULT_REGION'] = "us-west-2" return _aws(command, collect_memory=collect_memory, env_vars=env_vars, - wait_for_finish=wait_for_finish) + wait_for_finish=wait_for_finish, input_data=input_data) class BaseS3CLICommand(unittest.TestCase): @@ -1331,5 +1332,99 @@ def test_sync_file_with_spaces(self): self.assertEqual(p2.rc, 0) +class TestStreams(BaseS3CLICommand): + def test_upload(self): + """ + This tests uploading a small stream from stdin. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + 'This is a test') + + def test_unicode_upload(self): + """ + This tests being able to upload unicode from stdin. 
+ """ + unicode_str = u'\u00e9 This is a test' + byte_str = unicode_str.encode('utf-8') + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=byte_str) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assertEqual(self.get_key_contents(bucket_name, 'stream'), + unicode_str) + + def test_multipart_upload(self): + """ + This tests the ability to multipart upload streams from stdin. + The data has some unicode in it to avoid having to do a seperate + multipart upload test just for unicode. + """ + + bucket_name = self.create_bucket() + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + self.assertTrue(self.key_exists(bucket_name, 'stream')) + self.assert_key_contents_equal(bucket_name, 'stream', data) + + def test_download(self): + """ + This tests downloading a small stream from stdout. + """ + bucket_name = self.create_bucket() + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=b'This is a test') + self.assert_no_errors(p) + + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, 'This is a test') + + def test_unicode_download(self): + """ + This tests downloading a small unicode stream from stdout. + """ + bucket_name = self.create_bucket() + + data = u'\u00e9 This is a test' + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + self.assert_no_errors(p) + + # Downloading the unicode stream to standard out. + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + def test_multipart_download(self): + """ + This tests the ability to multipart download streams to stdout. + The data has some unicode in it to avoid having to do a seperate + multipart download test just for unicode. + """ + bucket_name = self.create_bucket() + + # First lets upload some data via streaming since + # its faster and we do not have to write to a file! + data = u'\u00e9bcd' * (1024 * 1024 * 10) + data_encoded = data.encode('utf-8') + p = aws('s3 cp - s3://%s/stream' % bucket_name, + input_data=data_encoded) + + # Download the unicode stream to standard out. + p = aws('s3 cp s3://%s/stream -' % bucket_name) + self.assert_no_errors(p) + self.assertEqual(p.stdout, data_encoded.decode(get_stdout_encoding())) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/s3/__init__.py b/tests/unit/customizations/s3/__init__.py index 5ada082edd3e..6986caa4fa6d 100644 --- a/tests/unit/customizations/s3/__init__.py +++ b/tests/unit/customizations/s3/__init__.py @@ -16,7 +16,7 @@ import string import six -from mock import patch +from mock import patch, Mock class S3HandlerBaseTest(unittest.TestCase): @@ -188,3 +188,24 @@ def list_buckets(session): html_response, response_data = operation.call(endpoint) contents = response_data['Buckets'] return contents + + +class MockStdIn(object): + """ + This class patches stdin in order to write a stream of bytes into + stdin. 
+ """ + def __init__(self, input_bytes=b''): + input_data = six.BytesIO(input_bytes) + if six.PY3: + mock_object = Mock() + mock_object.buffer = input_data + else: + mock_object = input_data + self._patch = patch('sys.stdin', mock_object) + + def __enter__(self): + self._patch.__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + self._patch.__exit__() diff --git a/tests/unit/customizations/s3/test_executor.py b/tests/unit/customizations/s3/test_executor.py index 311e718b5a1a..1a559562ec32 100644 --- a/tests/unit/customizations/s3/test_executor.py +++ b/tests/unit/customizations/s3/test_executor.py @@ -15,6 +15,7 @@ import shutil import six from six.moves import queue +import sys import mock @@ -41,7 +42,7 @@ def tearDown(self): shutil.rmtree(self.temp_dir) def test_handles_io_request(self): - self.queue.put(IORequest(self.filename, 0, b'foobar')) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) self.queue.put(IOCloseRequest(self.filename)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() @@ -49,8 +50,8 @@ def test_handles_io_request(self): self.assertEqual(f.read(), b'foobar') def test_out_of_order_io_requests(self): - self.queue.put(IORequest(self.filename, 6, b'morestuff')) - self.queue.put(IORequest(self.filename, 0, b'foobar')) + self.queue.put(IORequest(self.filename, 6, b'morestuff', False)) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) self.queue.put(IOCloseRequest(self.filename)) self.queue.put(ShutdownThreadRequest()) self.io_thread.run() @@ -60,8 +61,8 @@ def test_out_of_order_io_requests(self): def test_multiple_files_in_queue(self): second_file = os.path.join(self.temp_dir, 'bar') open(second_file, 'w').close() - self.queue.put(IORequest(self.filename, 0, b'foobar')) - self.queue.put(IORequest(second_file, 0, b'otherstuff')) + self.queue.put(IORequest(self.filename, 0, b'foobar', False)) + self.queue.put(IORequest(second_file, 0, b'otherstuff', False)) self.queue.put(IOCloseRequest(second_file)) self.queue.put(IOCloseRequest(self.filename)) self.queue.put(ShutdownThreadRequest()) @@ -72,6 +73,20 @@ def test_multiple_files_in_queue(self): with open(second_file, 'rb') as f: self.assertEqual(f.read(), b'otherstuff') + def test_stream_requests(self): + # Test that offset has no affect on the order in which requests + # are written to stdout. The order of requests for a stream are + # first in first out. + self.queue.put(IORequest('nonexistant-file', 10, b'foobar', True)) + self.queue.put(IORequest('nonexistant-file', 6, b'otherstuff', True)) + # The thread should not try to close the file name because it is + # writing to stdout. If it does, the thread will fail because + # the file does not exist. 
+ self.queue.put(ShutdownThreadRequest()) + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.io_thread.run() + self.assertEqual(mock_stdout.getvalue(), 'foobarotherstuff') + class TestExecutor(unittest.TestCase): def test_shutdown_does_not_hang(self): @@ -84,7 +99,8 @@ class FloodIOQueueTask(object): def __call__(self): for i in range(50): - executor.write_queue.put(IORequest(f.name, 0, b'foobar')) + executor.write_queue.put(IORequest(f.name, 0, + b'foobar', False)) executor.submit(FloodIOQueueTask()) executor.initiate_shutdown() executor.wait_until_shutdown() diff --git a/tests/unit/customizations/s3/test_fileinfo.py b/tests/unit/customizations/s3/test_fileinfo.py index 48a6651f42fb..bbee735fa047 100644 --- a/tests/unit/customizations/s3/test_fileinfo.py +++ b/tests/unit/customizations/s3/test_fileinfo.py @@ -21,6 +21,8 @@ from awscli.testutils import unittest from awscli.customizations.s3 import fileinfo +from awscli.customizations.s3.utils import MD5Error +from awscli.customizations.s3.fileinfo import FileInfo class TestSaveFile(unittest.TestCase): @@ -58,3 +60,25 @@ def test_makedir_other_exception(self, makedirs): fileinfo.save_file(self.filename, self.response_data, self.last_update) self.assertFalse(os.path.isfile(self.filename)) + + def test_stream_file(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + fileinfo.save_file(None, self.response_data, None, True) + self.assertEqual(mock_stdout.getvalue(), "foobar") + + def test_stream_file_md5_error(self): + with mock.patch('sys.stdout', new=six.StringIO()) as mock_stdout: + self.response_data['ETag'] = '"0"' + with self.assertRaises(MD5Error): + fileinfo.save_file(None, self.response_data, None, True) + # Make sure nothing is written to stdout. 
+ self.assertEqual(mock_stdout.getvalue(), "") + + +class TestSetSizeFromS3(unittest.TestCase): + def test_set_size_from_s3(self): + file_info = FileInfo(src="bucket/key", endpoint=None) + with mock.patch('awscli.customizations.s3.fileinfo.operate') as op_mock: + op_mock.return_value = ({'ContentLength': 5}, None) + file_info.set_size_from_s3() + self.assertEqual(file_info.size, 5) diff --git a/tests/unit/customizations/s3/test_fileinfobuilder.py b/tests/unit/customizations/s3/test_fileinfobuilder.py index 439c006ad136..7d235e5728de 100644 --- a/tests/unit/customizations/s3/test_fileinfobuilder.py +++ b/tests/unit/customizations/s3/test_fileinfobuilder.py @@ -22,7 +22,8 @@ class TestFileInfoBuilder(unittest.TestCase): def test_info_setter(self): info_setter = FileInfoBuilder(service='service', endpoint='endpoint', source_endpoint='source_endpoint', - parameters='parameters') + parameters='parameters', + is_stream='is_stream') files = [FileStat(src='src', dest='dest', compare_key='compare_key', size='size', last_update='last_update', src_type='src_type', dest_type='dest_type', diff --git a/tests/unit/customizations/s3/test_s3handler.py b/tests/unit/customizations/s3/test_s3handler.py index 5a5b416f8ad0..2105d3495770 100644 --- a/tests/unit/customizations/s3/test_s3handler.py +++ b/tests/unit/customizations/s3/test_s3handler.py @@ -14,15 +14,19 @@ import os import random import sys -from awscli.testutils import unittest +import mock + +from awscli.testutils import unittest from awscli import EnvironmentVariables -from awscli.customizations.s3.s3handler import S3Handler +from awscli.customizations.s3.s3handler import S3Handler, S3StreamHandler from awscli.customizations.s3.fileinfo import FileInfo +from awscli.customizations.s3.tasks import CreateMultipartUploadTask, \ + UploadPartTask, CreateLocalFileTask from tests.unit.customizations.s3.fake_session import FakeSession from tests.unit.customizations.s3 import make_loc_files, clean_loc_files, \ make_s3_files, s3_cleanup, create_bucket, list_contents, list_buckets, \ - S3HandlerBaseTest + S3HandlerBaseTest, MockStdIn class S3HandlerTestDeleteList(S3HandlerBaseTest): @@ -612,5 +616,167 @@ def test_bucket(self): self.assertEqual(orig_number_buckets, number_buckets) +class TestStreams(S3HandlerBaseTest): + def setUp(self): + super(TestStreams, self).setUp() + self.session = FakeSession() + self.service = self.session.get_service('s3') + self.endpoint = self.service.get_endpoint('us-east-1') + self.params = {'is_stream': True, 'region': 'us-east-1'} + + def test_pull_from_stream(self): + s3handler = S3StreamHandler(self.session, self.params, chunksize=2) + input_to_stdin = b'This is a test' + size = len(input_to_stdin) + # Retrieve the entire string. + with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin) + # Ensure the function exits when there is nothing to read. + with MockStdIn(): + payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, b'') + # Ensure the function does not grab too much out of stdin. + with MockStdIn(input_to_stdin): + payload, is_amount_requested = s3handler._pull_from_stream(size-2) + data = payload.read() + self.assertTrue(is_amount_requested) + self.assertEqual(data, input_to_stdin[:-2]) + # Retrieve the rest of standard in. 
+ payload, is_amount_requested = s3handler._pull_from_stream(size) + data = payload.read() + self.assertFalse(is_amount_requested) + self.assertEqual(data, input_to_stdin[-2:]) + + def test_upload_stream_not_multipart_task(self): + s3handler = S3StreamHandler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # No multipart upload should have been submitted. + self.assertEqual(len(submitted_tasks), 1) + self.assertEqual(submitted_tasks[0][0][0].payload.read(), + b'bar') + + def test_upload_stream_is_multipart_task(self): + s3handler = S3StreamHandler(self.session, self.params, + multi_threshold=1) + s3handler.executor = mock.Mock() + fileinfos = [FileInfo('filename', operation_name='upload', + is_stream=True, size=0)] + with MockStdIn(b'bar'): + s3handler._enqueue_tasks(fileinfos) + submitted_tasks = s3handler.executor.submit.call_args_list + # This should be a multipart upload so multiple tasks + # should have been submitted. + self.assertEqual(len(submitted_tasks), 4) + self.assertEqual(submitted_tasks[1][0][0]._payload.read(), + b'b') + self.assertEqual(submitted_tasks[2][0][0]._payload.read(), + b'ar') + + def test_upload_stream_with_expected_size(self): + self.params['expected_size'] = 100000 + # With this large of expected size, the chunksize of 2 will have + # to change. + s3handler = S3StreamHandler(self.session, self.params, chunksize=2) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + with MockStdIn(b'bar'): + s3handler._enqueue_multipart_upload_tasks(fileinfo, b'') + submitted_tasks = s3handler.executor.submit.call_args_list + # Determine what the chunksize was changed to from one of the + # UploadPartTasks. + changed_chunk_size = submitted_tasks[1][0][0]._chunk_size + # New chunksize should have a total parts under 1000. + self.assertTrue(100000/changed_chunk_size < 1000) + + def test_upload_stream_enqueue_upload_task(self): + s3handler = S3StreamHandler(self.session, self.params) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='upload', + is_stream=True) + stdin_input = b'This is a test' + with MockStdIn(stdin_input): + num_parts = s3handler._enqueue_upload_tasks(None, 2, mock.Mock(), + fileinfo, + UploadPartTask) + submitted_tasks = s3handler.executor.submit.call_args_list + # Ensure the returned number of parts is correct. + self.assertEqual(num_parts, len(submitted_tasks) + 1) + # Ensure the number of tasks uploaded are as expected + self.assertEqual(len(submitted_tasks), 8) + index = 0 + for i in range(len(submitted_tasks)-1): + self.assertEqual(submitted_tasks[i][0][0]._payload.read(), + stdin_input[index:index+2]) + index += 2 + # Ensure that the last part is an empty string as expected. + self.assertEqual(submitted_tasks[7][0][0]._payload.read(), b'') + + def test_enqueue_upload_single_part_task_stream(self): + """ + This test ensures that a payload gets attached to a task when + it is submitted to the executor. 
+ """ + s3handler = S3StreamHandler(self.session, self.params) + s3handler.executor = mock.Mock() + mock_task_class = mock.Mock() + s3handler._enqueue_upload_single_part_task( + part_number=1, chunk_size=2, upload_context=None, + filename=None, task_class=mock_task_class, + payload=b'This is a test' + ) + args, kwargs = mock_task_class.call_args + self.assertIn('payload', kwargs.keys()) + self.assertEqual(kwargs['payload'], b'This is a test') + + def test_enqueue_multipart_download_stream(self): + """ + This test ensures the right calls are made in ``_enqueue_tasks()`` + if the file should be a multipart download. + """ + s3handler = S3StreamHandler(self.session, self.params, + multi_threshold=5) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='download', + is_stream=True) + with mock.patch('awscli.customizations.s3.s3handler' + '.S3StreamHandler._enqueue_range_download_tasks') as \ + mock_enqueue_range_tasks: + with mock.patch('awscli.customizations.s3.fileinfo.FileInfo' + '.set_size_from_s3') as mock_set_size_from_s3: + # Set the file size to something larger than the multipart + # threshold. + fileinfo.size = 100 + # Run the main enqueue function. + s3handler._enqueue_tasks([fileinfo]) + # Assert that the size of the ``FileInfo`` object was set + # if we are downloading a stream. + self.assertTrue(mock_set_size_from_s3.called) + # Ensure that this download would have been a multipart + # download. + self.assertTrue(mock_enqueue_range_tasks.called) + + def test_enqueue_range_download_tasks_stream(self): + s3handler = S3StreamHandler(self.session, self.params, chunksize=100) + s3handler.executor = mock.Mock() + fileinfo = FileInfo('filename', operation_name='download', + is_stream=True, size=100) + s3handler._enqueue_range_download_tasks(fileinfo) + # Ensure that no request was sent to make a file locally. 
+ submitted_tasks = s3handler.executor.submit.call_args_list + self.assertNotEqual(type(submitted_tasks[0][0][0]), + CreateLocalFileTask) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/customizations/s3/test_subcommands.py b/tests/unit/customizations/s3/test_subcommands.py index 755ad4fdb675..fab4076b2752 100644 --- a/tests/unit/customizations/s3/test_subcommands.py +++ b/tests/unit/customizations/s3/test_subcommands.py @@ -174,12 +174,13 @@ def test_create_instructions(self): 'rb': ['s3_handler']} params = {'filters': True, 'region': 'us-east-1', 'endpoint_url': None, - 'verify_ssl': None} + 'verify_ssl': None, 'is_stream': False} for cmd in cmds: cmd_arc = CommandArchitecture(self.session, cmd, {'region': 'us-east-1', 'endpoint_url': None, - 'verify_ssl': None}) + 'verify_ssl': None, + 'is_stream': False}) cmd_arc.create_instructions() self.assertEqual(cmd_arc.instructions, instructions[cmd]) @@ -202,7 +203,8 @@ def test_run_cp_put(self): 'src': local_file, 'dest': s3_file, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -218,7 +220,8 @@ def test_error_on_same_line_as_status(self): 'src': local_file, 'dest': s3_file, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -241,7 +244,8 @@ def test_run_cp_get(self): 'src': s3_file, 'dest': local_file, 'filters': filters, 'paths_type': 's3local', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -258,7 +262,8 @@ def test_run_cp_copy(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3s3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'cp', params) cmd_arc.create_instructions() cmd_arc.run() @@ -275,7 +280,8 @@ def test_run_mv(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3s3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'mv', params) cmd_arc.create_instructions() cmd_arc.run() @@ -292,7 +298,8 @@ def test_run_remove(self): 'src': s3_file, 'dest': s3_file, 'filters': filters, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'rm', params) cmd_arc.create_instructions() cmd_arc.run() @@ -313,7 +320,8 @@ def test_run_sync(self): 'src': local_dir, 'dest': s3_prefix, 'filters': filters, 'paths_type': 'locals3', 'region': 'us-east-1', 'endpoint_url': None, 
'verify_ssl': None, - 'follow_symlinks': True, 'page_size': None} + 'follow_symlinks': True, 'page_size': None, + 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'sync', params) cmd_arc.create_instructions() cmd_arc.run() @@ -329,7 +337,7 @@ def test_run_mb(self): 'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, 'follow_symlinks': True, - 'page_size': None} + 'page_size': None, 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'mb', params) cmd_arc.create_instructions() cmd_arc.run() @@ -345,7 +353,7 @@ def test_run_rb(self): 'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, 'follow_symlinks': True, - 'page_size': None} + 'page_size': None, 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'rb', params) cmd_arc.create_instructions() rc = cmd_arc.run() @@ -362,7 +370,7 @@ def test_run_rb_nonzero_rc(self): 'src': s3_prefix, 'dest': s3_prefix, 'paths_type': 's3', 'region': 'us-east-1', 'endpoint_url': None, 'verify_ssl': None, 'follow_symlinks': True, - 'page_size': None} + 'page_size': None, 'is_stream': False} cmd_arc = CommandArchitecture(self.session, 'rb', params) cmd_arc.create_instructions() rc = cmd_arc.run() @@ -473,6 +481,34 @@ def test_check_force(self): cmd_params.parameters['src'] = 's3://mybucket' cmd_params.check_force(None) + def test_validate_streaming_paths_upload(self): + parameters = {'src': '-', 'dest': 's3://bucket'} + cmd_params = CommandParameters(self.session, 'cp', parameters, '') + cmd_params._validate_streaming_paths() + self.assertTrue(cmd_params.parameters['is_stream']) + self.assertTrue(cmd_params.parameters['only_show_errors']) + self.assertFalse(cmd_params.parameters['dir_op']) + + def test_validate_streaming_paths_download(self): + parameters = {'src': 'localfile', 'dest': '-'} + cmd_params = CommandParameters(self.session, 'cp', parameters, '') + cmd_params._validate_streaming_paths() + self.assertTrue(cmd_params.parameters['is_stream']) + self.assertTrue(cmd_params.parameters['only_show_errors']) + self.assertFalse(cmd_params.parameters['dir_op']) + + def test_validate_no_streaming_paths(self): + parameters = {'src': 'localfile', 'dest': 's3://bucket'} + cmd_params = CommandParameters(self.session, 'cp', parameters, '') + cmd_params._validate_streaming_paths() + self.assertFalse(cmd_params.parameters['is_stream']) + + def test_validate_streaming_paths_error(self): + parameters = {'src': '-', 'dest': 's3://bucket'} + cmd_params = CommandParameters(self.session, 'sync', parameters, '') + with self.assertRaises(ValueError): + cmd_params._validate_streaming_paths() + class HelpDocTest(BaseAWSHelpOutputTest): def setUp(self): diff --git a/tests/unit/customizations/s3/test_tasks.py b/tests/unit/customizations/s3/test_tasks.py index 4451c85cb569..eda16f765778 100644 --- a/tests/unit/customizations/s3/test_tasks.py +++ b/tests/unit/customizations/s3/test_tasks.py @@ -22,6 +22,7 @@ from awscli.customizations.s3.tasks import CompleteDownloadTask from awscli.customizations.s3.tasks import DownloadPartTask from awscli.customizations.s3.tasks import MultipartUploadContext +from awscli.customizations.s3.tasks import MultipartDownloadContext from awscli.customizations.s3.tasks import UploadCancelledError from awscli.customizations.s3.tasks import print_operation from awscli.customizations.s3.tasks import RetriesExeededError @@ -163,6 +164,58 @@ def test_basic_threaded_parts(self): 
self.calls[2][1:], ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}]))
+
+ def test_streaming_threaded_parts(self):
+ # This is similar to the basic threaded parts test, but here the
+ # thread has to wait to learn exactly how many parts are expected
+ # from the stream. This is indicated when the context's expected
+ # number of parts changes from '...' to an integer.
+
+ self.context = MultipartUploadContext(expected_parts='...')
+ upload_part_thread = threading.Thread(target=self.upload_part,
+ args=(1,))
+ # Once this thread starts it will immediately block.
+ self.start_thread(upload_part_thread)
+
+ # Also, let's start the thread that will do the complete
+ # multipart upload. It will also block because it needs all
+ # the parts, so it's blocked on the upload_part_thread. It also
+ # needs the upload_id, so it's blocked on that as well.
+ complete_upload_thread = threading.Thread(target=self.complete_upload)
+ self.start_thread(complete_upload_thread)
+
+ # Then finally the CreateMultipartUpload completes and we
+ # announce the upload id.
+ self.create_upload('my_upload_id')
+ # The complete upload thread should still be waiting for the
+ # expected number of parts.
+ with self.call_lock:
+ was_completed = (len(self.calls) > 2)
+
+ # The upload_part thread can now proceed, as can the complete
+ # multipart upload thread.
+ self.context.announce_total_parts(1)
+ self.join_threads()
+
+ self.assertIsNone(self.caught_exception)
+
+ # Make sure that the complete task was never called while it was
+ # still waiting for the total parts to be announced.
+ self.assertFalse(was_completed)
+
+ # We can verify that the invariants still hold.
+ self.assertEqual(len(self.calls), 3)
+ # There should be three calls: create, upload, complete.
+ self.assertEqual(self.calls[0][0], 'create_multipart_upload')
+ self.assertEqual(self.calls[1][0], 'upload_part')
+ self.assertEqual(self.calls[2][0], 'complete_upload')
+
+ # Verify the correct args were used.
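+ # The create call should record the announced upload id, the upload
+ # call the part number and upload id, and the complete call the
+ # upload id together with the collected parts list.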
+ self.assertEqual(self.calls[0][1], 'my_upload_id')
+ self.assertEqual(self.calls[1][1:], (1, 'my_upload_id'))
+ self.assertEqual(
+ self.calls[2][1:],
+ ('my_upload_id', [{'ETag': 'etag1', 'PartNumber': 1}]))
+
 def test_randomized_stress_test(self):
 # Now given that we've verified the functionality from
 # the two tests above, we randomize the threading to ensure
@@ -279,6 +332,7 @@ def setUp(self):
 self.filename.size = 10 * 1024 * 1024
 self.filename.src = 'bucket/key'
 self.filename.dest = 'local/file'
+ self.filename.is_stream = False
 self.filename.service = self.service
 self.filename.operation_name = 'download'
 self.context = mock.Mock()
@@ -325,9 +379,9 @@ def test_download_queues_io_properly(self):
 call_args_list = self.io_queue.put.call_args_list
 self.assertEqual(len(call_args_list), 2)
 self.assertEqual(call_args_list[0],
- mock.call(('local/file', 0, b'foobar')))
+ mock.call(('local/file', 0, b'foobar', False)))
 self.assertEqual(call_args_list[1],
- mock.call(('local/file', 6, b'morefoobar')))
+ mock.call(('local/file', 6, b'morefoobar', False)))
 def test_incomplete_read_is_retried(self):
 self.service.get_operation.return_value.call.side_effect = \
@@ -342,6 +396,61 @@ def test_incomplete_read_is_retried(self):
 self.service.get_operation.call_count)
+class TestMultipartDownloadContext(unittest.TestCase):
+ def setUp(self):
+ self.context = MultipartDownloadContext(num_parts=2)
+ self.calls = []
+ self.threads = []
+ self.call_lock = threading.Lock()
+ self.caught_exception = None
+
+ def tearDown(self):
+ self.join_threads()
+
+ def join_threads(self):
+ for thread in self.threads:
+ thread.join()
+
+ def download_stream_part(self, part_number):
+ try:
+ self.context.wait_for_turn(part_number)
+ with self.call_lock:
+ self.calls.append(('download_part', str(part_number)))
+ self.context.done_with_turn()
+ except Exception as e:
+ self.caught_exception = e
+ return
+
+ def start_thread(self, thread):
+ thread.start()
+ self.threads.append(thread)
+
+ def test_stream_context(self):
+ part_thread = threading.Thread(target=self.download_stream_part,
+ args=(1,))
+ # Once this thread starts it will immediately block because it is
+ # waiting for part zero to finish submitting its task.
+ self.start_thread(part_thread)
+
+ # Now create the thread that should submit its task first.
+ part_thread2 = threading.Thread(target=self.download_stream_part,
+ args=(0,))
+ self.start_thread(part_thread2)
+ self.join_threads()
+
+ self.assertIsNone(self.caught_exception)
+
+ # We can verify that the invariants still hold.
+ self.assertEqual(len(self.calls), 2)
+ # Both calls should be download_part calls.
+ self.assertEqual(self.calls[0][0], 'download_part')
+ self.assertEqual(self.calls[1][0], 'download_part')
+
+ # Verify the parts were processed in the correct order.
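+ # Part zero should have taken its turn before part one, since the
+ # download context serializes the threads by part number.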
+ self.assertEqual(self.calls[0][1], '0') + self.assertEqual(self.calls[1][1], '1') + + class TestTaskOrdering(unittest.TestCase): def setUp(self): self.q = StablePriorityQueue(maxsize=10, max_priority=20) diff --git a/tests/unit/test_completer.py b/tests/unit/test_completer.py index d3f6843971d6..bd41e0e0f063 100644 --- a/tests/unit/test_completer.py +++ b/tests/unit/test_completer.py @@ -73,7 +73,8 @@ '--cache-control', '--content-type', '--content-disposition', '--source-region', '--content-encoding', '--content-language', - '--expires', '--grants', '--only-show-errors'] + '--expires', '--grants', '--only-show-errors', + '--expected-size'] + GLOBALOPTS)), ('aws s3 cp --quiet -', -1, set(['--no-guess-mime-type', '--dryrun', '--recursive', '--content-type', @@ -84,7 +85,8 @@ '--storage-class', '--sse', '--exclude', '--include', '--source-region', - '--grants', '--only-show-errors'] + '--grants', '--only-show-errors', + '--expected-size'] + GLOBALOPTS)), ('aws emr ', -1, set(['add-instance-groups', 'add-steps', 'add-tags', 'create-cluster', 'create-default-roles',