Cleaned up the streaming code.
This includes adding more tests, simplifying the code, and some PEP8 cleaning.
kyleknap committed Sep 16, 2014
1 parent 19ea686 commit 9022a59
Showing 6 changed files with 56 additions and 38 deletions.
1 change: 0 additions & 1 deletion awscli/customizations/s3/constants.py
@@ -18,4 +18,3 @@
MAX_SINGLE_UPLOAD_SIZE = 5 * (1024 ** 3)
MAX_UPLOAD_SIZE = 5 * (1024 ** 4)
MAX_QUEUE_SIZE = 1000
STREAM_INPUT_TIMEOUT = 0.1
3 changes: 2 additions & 1 deletion awscli/customizations/s3/executor.py
@@ -50,7 +50,8 @@ def __init__(self, num_threads, result_queue,
self.quiet = quiet
self.threads_list = []
self.write_queue = write_queue
self.print_thread = PrintThread(self.result_queue, self.quiet)
self.print_thread = PrintThread(self.result_queue,
self.quiet)
self.print_thread.daemon = True
self.io_thread = IOWriterThread(self.write_queue)

43 changes: 15 additions & 28 deletions awscli/customizations/s3/s3handler.py
@@ -17,10 +17,9 @@
import six
from six.moves import queue
import sys
import time

from awscli.customizations.s3.constants import MULTI_THRESHOLD, CHUNKSIZE, \
NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE, STREAM_INPUT_TIMEOUT
NUM_THREADS, MAX_UPLOAD_SIZE, MAX_QUEUE_SIZE
from awscli.customizations.s3.utils import find_chunksize, \
operate, find_bucket_key, relative_path, PrintTask, create_warning
from awscli.customizations.s3.executor import Executor
@@ -349,8 +348,8 @@ class S3StreamHandler(S3Handler):
downloading streams.
"""

# This ensures that at most the number of multipart chunks
# waiting in the executor queue and in the threads is limited.
# This ensures that the number of multipart chunks waiting in the
# executor queue and in the threads is limited.
MAX_EXECUTOR_QUEUE_SIZE = 2
EXECUTOR_NUM_THREADS = 6
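
For intuition, a generic sketch of how a bounded queue caps the number of chunks in flight (illustrative only, not the aws-cli Executor; maxsize=2 and the six worker threads mirror the constants above):

    import threading
    from six.moves import queue

    task_queue = queue.Queue(maxsize=2)  # at most two chunks wait for a worker

    def worker():
        while True:
            chunk = task_queue.get()
            # ... transfer the chunk ...
            task_queue.task_done()

    for _ in range(6):  # mirrors EXECUTOR_NUM_THREADS
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()

    for chunk in range(10):
        # put() blocks once two chunks are already queued, so the stream is
        # only consumed as fast as the workers can drain it.
        task_queue.put(chunk)
    task_queue.join()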

@@ -366,7 +365,9 @@ def _enqueue_tasks(self, files):
payload, is_multipart_task = \
self._pull_from_stream(self.multi_threshold)
else:
# Set the file size for the file object
# Set the file size for the ``FileInfo`` object since
# streams do not use a ``FileGenerator`` that usually
# determines the size.
filename.set_size_from_s3()
is_multipart_task = self._is_multipart_task(filename)
if is_multipart_task and not self.params['dryrun']:
@@ -388,36 +389,20 @@
total_parts += num_uploads
return total_files, total_parts

def _pull_from_stream(self, initial_amount_requested):
def _pull_from_stream(self, amount_requested):
"""
This function keeps pulling data from stdin until it hits the amount
requested or there is no more left ot pull in from stdin. The
This function pulls data from stdin until it hits the amount
requested or there is no more left to pull in from stdin. The
function wraps the data into a ``BytesIO`` object that is returned
along with a boolean telling whether the amount requested is
the amount returned.
"""
size = 0
amount_requested = initial_amount_requested
total_retries = 0
payload = b''
stream_filein = sys.stdin
if six.PY3:
stream_filein = sys.stdin.buffer
while True:
payload_chunk = stream_filein.read(amount_requested)
payload_chunk_size = len(payload_chunk)
payload += payload_chunk
size += payload_chunk_size
amount_requested -= payload_chunk_size
if payload_chunk_size == 0:
time.sleep(STREAM_INPUT_TIMEOUT)
total_retries += 1
else:
total_retries = 0
if amount_requested == 0 or total_retries == 5:
break
payload = stream_filein.read(amount_requested)
payload_file = six.BytesIO(payload)
return payload_file, size == initial_amount_requested
return payload_file, len(payload) == amount_requested
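
For context, the simplified body above boils down to a single blocking read; a minimal standalone sketch of the same idea (illustrative only, assuming the same six/stdin handling as the diff):

    import sys
    import six

    def pull_from_stream(amount_requested):
        # On Python 3 the binary stream lives behind sys.stdin.buffer.
        stream_filein = sys.stdin
        if six.PY3:
            stream_filein = sys.stdin.buffer
        # read(n) blocks until n bytes arrive or stdin hits EOF, so no
        # retry/sleep loop is needed.
        payload = stream_filein.read(amount_requested)
        # The boolean tells the caller whether a full chunk was available,
        # i.e. whether more data may follow.
        return six.BytesIO(payload), len(payload) == amount_requested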

def _enqueue_multipart_tasks(self, filename, payload=None):
num_uploads = 1
@@ -435,7 +420,9 @@ def _enqueue_range_download_tasks(self, filename, remove_remote_file=False):
num_downloads = int(filename.size / chunksize)
context = tasks.MultipartDownloadContext(num_downloads)

# No file is needed for downloading a stream.
# No file is needed for downloading a stream. So just announce
# that it has been made since it is required for the context to
# begin downloading.
context.announce_file_created()

# Submit download part tasks to the executor.
@@ -478,7 +465,7 @@ def _enqueue_multipart_upload_tasks(self, filename, payload=None):
filename, tasks.UploadPartTask
)

# Submit a task to notify the
# Submit a task to notify the multipart upload is complete.
self._enqueue_upload_end_task(filename, upload_context)

return num_uploads
1 change: 1 addition & 0 deletions tests/unit/customizations/s3/__init__.py
@@ -33,6 +33,7 @@ def setUp(self):
def tearDown(self):
self.wait_timeout_patch.stop()


def make_loc_files():
"""
This sets up the test by making a directory named some_directory. It
10 changes: 10 additions & 0 deletions tests/unit/customizations/s3/test_fileinfo.py
@@ -22,6 +22,7 @@
from awscli.testutils import unittest
from awscli.customizations.s3 import fileinfo
from awscli.customizations.s3.utils import MD5Error
from awscli.customizations.s3.fileinfo import FileInfo


class TestSaveFile(unittest.TestCase):
@@ -72,3 +73,12 @@ def test_stream_file_md5_error(self):
fileinfo.save_file(None, self.response_data, None, True)
# Make sure nothing is written to stdout.
self.assertEqual(mock_stdout.getvalue(), "")


class TestSetSizeFromS3(unittest.TestCase):
def test_set_size_from_s3(self):
file_info = FileInfo(src="bucket/key", endpoint=None)
with mock.patch('awscli.customizations.s3.fileinfo.operate') as op_mock:
op_mock.return_value = ({'ContentLength': 5}, None)
file_info.set_size_from_s3()
self.assertEqual(file_info.size, 5)
36 changes: 28 additions & 8 deletions tests/unit/customizations/s3/test_s3handler.py
@@ -623,13 +623,6 @@ def setUp(self):
self.service = self.session.get_service('s3')
self.endpoint = self.service.get_endpoint('us-east-1')
self.params = {'is_stream': True, 'region': 'us-east-1'}
stream_timeout = 'awscli.customizations.s3.constants.STREAM_INPUT_TIMEOUT'
self.stream_timeout_patch = mock.patch(stream_timeout, 0.001)
self.stream_timeout_patch.start()

def tearDown(self):
super(TestStreams, self).tearDown()
self.stream_timeout_patch.stop()

def test_pull_from_stream(self):
s3handler = S3StreamHandler(self.session, self.params, chunksize=2)
@@ -674,7 +667,7 @@ def test_upload_stream_not_multipart_task(self):

def test_upload_stream_is_multipart_task(self):
s3handler = S3StreamHandler(self.session, self.params,
multi_threshold=1)
multi_threshold=1)
s3handler.executor = mock.Mock()
fileinfos = [FileInfo('filename', operation_name='upload',
is_stream=True, size=0)]
@@ -746,6 +739,33 @@ def test_enqueue_upload_single_part_task_stream(self):
self.assertIn('payload', kwargs.keys())
self.assertEqual(kwargs['payload'], b'This is a test')

def test_enqueue_multipart_download_stream(self):
"""
This test ensures the right calls are made in ``_enqueue_tasks()``
if the file should be a multipart download.
"""
s3handler = S3StreamHandler(self.session, self.params,
multi_threshold=5)
s3handler.executor = mock.Mock()
fileinfo = FileInfo('filename', operation_name='download',
is_stream=True)
with mock.patch('awscli.customizations.s3.s3handler'
'.S3StreamHandler._enqueue_range_download_tasks') as \
mock_enqueue_range_tasks:
with mock.patch('awscli.customizations.s3.fileinfo.FileInfo'
'.set_size_from_s3') as mock_set_size_from_s3:
# Set the file size to something larger than the multipart
# threshold.
fileinfo.size = 100
# Run the main enqueue function.
s3handler._enqueue_tasks([fileinfo])
# Assert that the size of the ``FileInfo`` object was set
# if we are downloading a stream.
self.assertTrue(mock_set_size_from_s3.called)
# Ensure that this download would have been a multipart
# download.
self.assertTrue(mock_enqueue_range_tasks.called)

def test_enqueue_range_download_tasks_stream(self):
s3handler = S3StreamHandler(self.session, self.params, chunksize=100)
s3handler.executor = mock.Mock()
