Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry intermittent S3 download failures #594

Merged
merged 1 commit into from
Jan 15, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 50 additions & 24 deletions awscli/customizations/s3/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import math
import os
import time
import socket
import threading

from botocore.vendored import requests
from botocore.exceptions import IncompleteReadError

from awscli.customizations.s3.utils import find_bucket_key, MD5Error, \
operate, ReadFileChunk, relative_path
Expand All @@ -21,6 +23,10 @@ class DownloadCancelledError(Exception):
pass


class RetriesExeededError(Exception):
    """Raised when a part download keeps failing past TOTAL_ATTEMPTS.

    NOTE(review): the name is misspelled ("Exeeded" should be
    "Exceeded"); it is kept as-is because callers import this exact
    name (e.g. the unit tests).
    """


def print_operation(filename, failed, dryrun=False):
"""
Helper function used to print out what an operation did and whether
Expand Down Expand Up @@ -292,17 +298,30 @@ class DownloadPartTask(object):

# Amount to read from response body at a time.
ITERATE_CHUNK_SIZE = 1024 * 1024
READ_TIMEOUT = 60
TOTAL_ATTEMPTS = 5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does this come from? Given our typical number of retries is 6, I'm just curious.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe botocore uses 5 by default (except for special cases like ddb).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


def __init__(self, part_number, chunk_size, result_queue, service,
             filename, context, open=open):
    """Task that downloads a single byte range of an S3 object.

    :param part_number: zero-based index of the part to download.
    :param chunk_size: size in bytes of each part.
    :param result_queue: queue that receives result/progress dicts.
    :param service: botocore service object.  NOTE(review): this
        argument is not stored -- the service attached to ``filename``
        is used instead; presumably intentional, verify with callers.
    :param filename: file-info object describing the transfer
        (provides ``size``, ``src``, ``dest``, ``service``).
    :param context: shared download context used to coordinate and
        cancel sibling part tasks.
    :param open: injectable ``open`` callable, overridable for tests.
    """
    self._part_number = part_number
    self._chunk_size = chunk_size
    self._filename = filename
    # The service comes from the filename object, not the ``service``
    # parameter (see NOTE above).
    self._service = filename.service
    self._result_queue = result_queue
    self._context = context
    self._open = open

def __call__(self):
    """Execute the part download, cancelling the transfer on failure.

    Any exception raised by ``_download_part`` is logged with its
    traceback, the shared download context is cancelled so sibling
    part tasks stop, and the exception is re-raised to the caller.
    """
    try:
        self._download_part()
    except Exception as e:
        LOGGER.debug(
            'Exception caught downloading byte range: %s',
            e, exc_info=True)
        self._context.cancel()
        # Bare ``raise`` preserves the original traceback; ``raise e``
        # would rebind the traceback to this frame, hiding where the
        # error actually occurred.
        raise

def _download_part(self):
total_file_size = self._filename.size
start_range = self._part_number * self._chunk_size
if self._part_number == int(total_file_size / self._chunk_size) - 1:
Expand All @@ -315,34 +334,42 @@ def __call__(self):
bucket, key = find_bucket_key(self._filename.src)
params = {'endpoint': self._filename.endpoint, 'bucket': bucket,
'key': key, 'range': range_param}
try:
LOGGER.debug("Making GetObject requests with byte range: %s",
range_param)
response_data, http = operate(self._service, 'GetObject',
params)
LOGGER.debug("Response received from GetObject")
body = response_data['Body']
self._write_to_file(body)
self._context.announce_completed_part(self._part_number)

message = print_operation(self._filename, 0)
total_parts = int(self._filename.size / self._chunk_size)
result = {'message': message, 'error': False,
'total_parts': total_parts}
self._result_queue.put(result)
except Exception as e:
LOGGER.debug(
'Exception caught downloading byte range: %s',
e, exc_info=True)
self._context.cancel()
raise e
for i in range(self.TOTAL_ATTEMPTS):
try:
LOGGER.debug("Making GetObject requests with byte range: %s",
range_param)
response_data, http = operate(self._service, 'GetObject',
params)
LOGGER.debug("Response received from GetObject")
body = response_data['Body']
self._write_to_file(body)
self._context.announce_completed_part(self._part_number)

message = print_operation(self._filename, 0)
total_parts = int(self._filename.size / self._chunk_size)
result = {'message': message, 'error': False,
'total_parts': total_parts}
self._result_queue.put(result)
return
except (socket.timeout, socket.error) as e:
LOGGER.debug("Socket timeout caught, retrying request, "
"(attempt %s / %s)", i, self.TOTAL_ATTEMPTS,
exc_info=True)
continue
except IncompleteReadError as e:
LOGGER.debug("Incomplete read detected: %s, (attempt %s / %s)",
e, i, self.TOTAL_ATTEMPTS)
continue
raise RetriesExeededError("Maximum number of attempts exceeded: %s" %
self.TOTAL_ATTEMPTS)

def _write_to_file(self, body):
self._context.wait_for_file_created()
LOGGER.debug("Writing part number %s to file: %s",
self._part_number, self._filename.dest)
iterate_chunk_size = self.ITERATE_CHUNK_SIZE
with open(self._filename.dest, 'rb+') as f:
body.set_socket_timeout(self.READ_TIMEOUT)
with self._open(self._filename.dest, 'rb+') as f:
f.seek(self._part_number * self._chunk_size)
current = body.read(iterate_chunk_size)
while current:
Expand All @@ -352,7 +379,6 @@ def _write_to_file(self, body):
self._part_number, self._filename.dest)



class CreateMultipartUploadTask(BasicTask):
def __init__(self, session, filename, parameters, result_queue,
upload_context):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/customizations/s3/fake_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def get_object(self, kwargs):
else:
body = body[int(beginning):(int(end) + 1)]
mock_response = BytesIO(body)
mock_response.set_socket_timeout = Mock()
response_data['Body'] = mock_response
etag = self.session.s3[bucket][key]['ETag']
response_data['ETag'] = etag + '--'
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/customizations/s3/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
import random
import threading
import mock
import socket

from botocore.exceptions import IncompleteReadError

from awscli.customizations.s3.tasks import DownloadPartTask
from awscli.customizations.s3.tasks import MultipartUploadContext
from awscli.customizations.s3.tasks import UploadCancelledError
from awscli.customizations.s3.tasks import print_operation
from awscli.customizations.s3.tasks import RetriesExeededError


class TestMultipartUploadContext(unittest.TestCase):
Expand Down Expand Up @@ -239,3 +244,56 @@ def test_print_operation(self):
filename.dest_type = 's3'
message = print_operation(filename, failed=False)
self.assertIn(r'e:\foo', message)


class TestDownloadPartTask(unittest.TestCase):
def setUp(self):
    """Build the mock collaborators shared by every DownloadPartTask test."""
    self.result_queue = mock.Mock()
    self.service = mock.Mock()
    self.context = mock.Mock()
    self.open = mock.MagicMock()
    # Fake a 10 MB source file whose ``service`` attribute points back
    # at the mock service created above.
    filename = mock.Mock()
    filename.size = 10 * 1024 * 1024
    filename.src = 'bucket/key'
    filename.dest = 'local/file'
    filename.service = self.service
    filename.operation_name = 'download'
    self.filename = filename

def test_socket_timeout_is_retried(self):
    """A socket error on every attempt exhausts the retries and cancels."""
    get_object = self.service.get_operation.return_value.call
    get_object.side_effect = socket.error
    task = DownloadPartTask(1, 1024 * 1024, self.result_queue,
                            self.service, self.filename, self.context)
    # Every attempt raises, so the task must give up and cancel the
    # shared download context.
    with self.assertRaises(RetriesExeededError):
        task()
    self.context.cancel.assert_called_with()
    # The request was attempted the maximum number of times.
    self.assertEqual(DownloadPartTask.TOTAL_ATTEMPTS,
                     self.service.get_operation.call_count)

def test_download_succeeds(self):
    """After one socket error, the retry succeeds and queues one result."""
    body = mock.Mock()
    body.read.return_value = b''
    # First attempt raises socket.error, second returns a
    # (http, response_data) pair as ``operate`` would.
    self.service.get_operation.return_value.call.side_effect = [
        socket.error, (mock.Mock(), {'Body': body})]
    # The unused local ``context = mock.Mock()`` from the original was
    # dead code -- the task is constructed with ``self.context``.
    task = DownloadPartTask(1, 1024 * 1024, self.result_queue,
                            self.service, self.filename, self.context,
                            self.open)
    task()
    # Exactly one success result should have been queued.
    self.assertEqual(self.result_queue.put.call_count, 1)
    # And we tried twice: the first attempt failed, the second
    # succeeded.
    self.assertEqual(self.service.get_operation.call_count, 2)

def test_incomplete_read_is_retried(self):
    """An IncompleteReadError on every attempt also exhausts the retries."""
    error = IncompleteReadError(actual_bytes=1, expected_bytes=2)
    self.service.get_operation.return_value.call.side_effect = error
    task = DownloadPartTask(1, 1024 * 1024, self.result_queue,
                            self.service, self.filename, self.context)
    with self.assertRaises(RetriesExeededError):
        task()
    self.context.cancel.assert_called_with()
    self.assertEqual(DownloadPartTask.TOTAL_ATTEMPTS,
                     self.service.get_operation.call_count)