diff --git a/awscli/customizations/s3/tasks.py b/awscli/customizations/s3/tasks.py index 1c418541841f..9b079e66da3c 100644 --- a/awscli/customizations/s3/tasks.py +++ b/awscli/customizations/s3/tasks.py @@ -2,9 +2,11 @@ import math import os import time +import socket import threading from botocore.vendored import requests +from botocore.exceptions import IncompleteReadError from awscli.customizations.s3.utils import find_bucket_key, MD5Error, \ operate, ReadFileChunk, relative_path @@ -21,6 +23,10 @@ class DownloadCancelledError(Exception): pass +class RetriesExeededError(Exception): + pass + + def print_operation(filename, failed, dryrun=False): """ Helper function used to print out what an operation did and whether @@ -292,17 +298,30 @@ class DownloadPartTask(object): # Amount to read from response body at a time. ITERATE_CHUNK_SIZE = 1024 * 1024 + READ_TIMEOUT = 60 + TOTAL_ATTEMPTS = 5 def __init__(self, part_number, chunk_size, result_queue, service, - filename, context): + filename, context, open=open): self._part_number = part_number self._chunk_size = chunk_size self._result_queue = result_queue self._filename = filename self._service = filename.service self._context = context + self._open = open def __call__(self): + try: + self._download_part() + except Exception as e: + LOGGER.debug( + 'Exception caught downloading byte range: %s', + e, exc_info=True) + self._context.cancel() + raise e + + def _download_part(self): total_file_size = self._filename.size start_range = self._part_number * self._chunk_size if self._part_number == int(total_file_size / self._chunk_size) - 1: @@ -315,34 +334,42 @@ def __call__(self): bucket, key = find_bucket_key(self._filename.src) params = {'endpoint': self._filename.endpoint, 'bucket': bucket, 'key': key, 'range': range_param} - try: - LOGGER.debug("Making GetObject requests with byte range: %s", - range_param) - response_data, http = operate(self._service, 'GetObject', - params) - LOGGER.debug("Response received from GetObject") - body = response_data['Body'] - self._write_to_file(body) - self._context.announce_completed_part(self._part_number) - - message = print_operation(self._filename, 0) - total_parts = int(self._filename.size / self._chunk_size) - result = {'message': message, 'error': False, - 'total_parts': total_parts} - self._result_queue.put(result) - except Exception as e: - LOGGER.debug( - 'Exception caught downloading byte range: %s', - e, exc_info=True) - self._context.cancel() - raise e + for i in range(self.TOTAL_ATTEMPTS): + try: + LOGGER.debug("Making GetObject requests with byte range: %s", + range_param) + response_data, http = operate(self._service, 'GetObject', + params) + LOGGER.debug("Response received from GetObject") + body = response_data['Body'] + self._write_to_file(body) + self._context.announce_completed_part(self._part_number) + + message = print_operation(self._filename, 0) + total_parts = int(self._filename.size / self._chunk_size) + result = {'message': message, 'error': False, + 'total_parts': total_parts} + self._result_queue.put(result) + return + except (socket.timeout, socket.error) as e: + LOGGER.debug("Socket timeout caught, retrying request, " + "(attempt %s / %s)", i, self.TOTAL_ATTEMPTS, + exc_info=True) + continue + except IncompleteReadError as e: + LOGGER.debug("Incomplete read detected: %s, (attempt %s / %s)", + e, i, self.TOTAL_ATTEMPTS) + continue + raise RetriesExeededError("Maximum number of attempts exceeded: %s" % + self.TOTAL_ATTEMPTS) def _write_to_file(self, body): self._context.wait_for_file_created() LOGGER.debug("Writing part number %s to file: %s", self._part_number, self._filename.dest) iterate_chunk_size = self.ITERATE_CHUNK_SIZE - with open(self._filename.dest, 'rb+') as f: + body.set_socket_timeout(self.READ_TIMEOUT) + with self._open(self._filename.dest, 'rb+') as f: f.seek(self._part_number * self._chunk_size) current = body.read(iterate_chunk_size) while current: @@ -352,7 +379,6 @@ def _write_to_file(self, body): self._part_number, self._filename.dest) - class CreateMultipartUploadTask(BasicTask): def __init__(self, session, filename, parameters, result_queue, upload_context): diff --git a/tests/unit/customizations/s3/fake_session.py b/tests/unit/customizations/s3/fake_session.py index 5eaecbe7b4ef..c902ec2712fa 100644 --- a/tests/unit/customizations/s3/fake_session.py +++ b/tests/unit/customizations/s3/fake_session.py @@ -262,6 +262,7 @@ def get_object(self, kwargs): else: body = body[int(beginning):(int(end) + 1)] mock_response = BytesIO(body) + mock_response.set_socket_timeout = Mock() response_data['Body'] = mock_response etag = self.session.s3[bucket][key]['ETag'] response_data['ETag'] = etag + '--' diff --git a/tests/unit/customizations/s3/test_tasks.py b/tests/unit/customizations/s3/test_tasks.py index 8fa1892e7b7e..e9177706bf36 100644 --- a/tests/unit/customizations/s3/test_tasks.py +++ b/tests/unit/customizations/s3/test_tasks.py @@ -14,10 +14,15 @@ import random import threading import mock +import socket +from botocore.exceptions import IncompleteReadError + +from awscli.customizations.s3.tasks import DownloadPartTask from awscli.customizations.s3.tasks import MultipartUploadContext from awscli.customizations.s3.tasks import UploadCancelledError from awscli.customizations.s3.tasks import print_operation +from awscli.customizations.s3.tasks import RetriesExeededError class TestMultipartUploadContext(unittest.TestCase): @@ -239,3 +244,56 @@ def test_print_operation(self): filename.dest_type = 's3' message = print_operation(filename, failed=False) self.assertIn(r'e:\foo', message) + + +class TestDownloadPartTask(unittest.TestCase): + def setUp(self): + self.result_queue = mock.Mock() + self.service = mock.Mock() + self.filename = mock.Mock() + self.filename.size = 10 * 1024 * 1024 + self.filename.src = 'bucket/key' + self.filename.dest = 'local/file' + self.filename.service = self.service + self.filename.operation_name = 'download' + self.context = mock.Mock() + self.open = mock.MagicMock() + + def test_socket_timeout_is_retried(self): + self.service.get_operation.return_value.call.side_effect = socket.error + task = DownloadPartTask(1, 1024 * 1024, self.result_queue, + self.service, self.filename, self.context) + # The mock is configured to keep raising a socket.error + # so we should cancel the download. + with self.assertRaises(RetriesExeededError): + task() + self.context.cancel.assert_called_with() + # And we retried the request multiple times. + self.assertEqual(DownloadPartTask.TOTAL_ATTEMPTS, + self.service.get_operation.call_count) + + def test_download_succeeds(self): + body = mock.Mock() + body.read.return_value = b'' + self.service.get_operation.return_value.call.side_effect = [ + socket.error, (mock.Mock(), {'Body': body})] + context = mock.Mock() + task = DownloadPartTask(1, 1024 * 1024, self.result_queue, + self.service, self.filename, self.context, + self.open) + task() + self.assertEqual(self.result_queue.put.call_count, 1) + # And we tried twice, the first one failed, the second one + # succeeded. + self.assertEqual(self.service.get_operation.call_count, 2) + + def test_incomplete_read_is_retried(self): + self.service.get_operation.return_value.call.side_effect = \ + IncompleteReadError(actual_bytes=1, expected_bytes=2) + task = DownloadPartTask(1, 1024 * 1024, self.result_queue, + self.service, self.filename, self.context) + with self.assertRaises(RetriesExeededError): + task() + self.context.cancel.assert_called_with() + self.assertEqual(DownloadPartTask.TOTAL_ATTEMPTS, + self.service.get_operation.call_count)