From 61e387e361fe20c323f430970a22301906c90e8f Mon Sep 17 00:00:00 2001 From: kyleknap Date: Wed, 22 Nov 2017 23:04:09 -0800 Subject: [PATCH 1/3] Add max_bandwidth option This limits the rate in which uploads and downloads can stream content to and from S3. The abstraction uses a leaky bucket to control bandwidth consumption. --- .../enhancement-maxbandwidth-87115.json | 5 + s3transfer/bandwidth.py | 409 ++++++++++++++++ s3transfer/download.py | 26 +- s3transfer/manager.py | 46 +- s3transfer/upload.py | 26 +- s3transfer/utils.py | 22 +- tests/functional/test_download.py | 29 ++ tests/functional/test_manager.py | 24 +- tests/functional/test_upload.py | 37 +- tests/unit/test_bandwidth.py | 458 ++++++++++++++++++ tests/unit/test_download.py | 17 + tests/unit/test_utils.py | 51 ++ 12 files changed, 1104 insertions(+), 46 deletions(-) create mode 100644 .changes/next-release/enhancement-maxbandwidth-87115.json create mode 100644 s3transfer/bandwidth.py create mode 100644 tests/unit/test_bandwidth.py diff --git a/.changes/next-release/enhancement-maxbandwidth-87115.json b/.changes/next-release/enhancement-maxbandwidth-87115.json new file mode 100644 index 00000000..f730339e --- /dev/null +++ b/.changes/next-release/enhancement-maxbandwidth-87115.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``max_bandwidth``", + "description": "Add ability to set maximum bandwidth consumption for streaming of S3 uploads and downloads" +} diff --git a/s3transfer/bandwidth.py b/s3transfer/bandwidth.py new file mode 100644 index 00000000..cb693e28 --- /dev/null +++ b/s3transfer/bandwidth.py @@ -0,0 +1,409 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import time +import threading + + +class RequestExceededException(Exception): + def __init__(self, requested_amt, retry_time): + """Error when requested amount exceeds what is allowed + + The request that raised this error should be retried after waiting + the time specified by ``retry_time``. + + :type requested_amt: int + :param requested_amt: The originally requested byte amount + + :type retry_time: float + :param retry_time: The length in time to wait to retry for the + requested amount + """ + self.requested_amt = requested_amt + self.retry_time = retry_time + msg = ( + 'Request amount %s exceeded the amount available. 
Retry in %s' % ( + requested_amt, retry_time) + ) + super(RequestExceededException, self).__init__(msg) + + +class RequestToken(object): + """A token to pass as an identifier when consuming from the LeakyBucket""" + pass + + +class TimeUtils(object): + def time(self): + """Get the current time back + + :rtype: float + :returns: The current time in seconds + """ + return time.time() + + def sleep(self, value): + """Sleep for a designated time + + :type value: float + :param value: The time to sleep for in seconds + """ + return time.sleep(value) + + +class BandwidthLimiter(object): + def __init__(self, token_bucket, time_utils=None): + """Limits bandwidth for shared S3 transfers + + :type leaky_bucket: LeakyBucket + :param token_bucket: The leaky bucket to use limit bandwidth + + :type time_utils: TimeUtils + :param time_utils: Time utility to use for interacting with time. + """ + self._token_bucket = token_bucket + self._time_utils = time_utils + if time_utils is None: + self._time_utils = TimeUtils() + + def get_bandwith_limited_stream(self, fileobj, transfer_coordinator, + enabled=True): + """Wraps a fileobj in a bandwidth limited stream wrapper + + :type fileobj: file-like obj + :param fileobj: The file-like obj to wrap + + :type transfer_coordinator: s3transfer.futures.TransferCoordinator + param transfer_coordinator: The coordinator for the general transfer + that the wrapped stream is a part of + + :type enabled: boolean + :param enabled: Whether bandwidth limiting should be enabled to start + """ + stream = BandwidthLimitedStream( + fileobj, self._token_bucket, transfer_coordinator, + self._time_utils) + if not enabled: + stream.disable_bandwidth_limiting() + return stream + + +class BandwidthLimitedStream(object): + def __init__(self, fileobj, leaky_bucket, transfer_coordinator, + time_utils=None, bytes_threshold=256 * 1024): + """Limits bandwidth for reads on a wrapped stream + + :type fileobj: file-like object + :param fileobj: The file like object to wrap + + :type token_bucket: LeakyBucket + :param token_bucket: The leaky bucket to use to throttle reads on + the stream + + :type transfer_coordinator: s3transfer.futures.TransferCoordinator + param transfer_coordinator: The coordinator for the general transfer + that the wrapped stream is a part of + + :type time_utils: TimeUtils + :param time_utils: The time utility to use for interacting with time + """ + self._fileobj = fileobj + self._leaky_bucket = leaky_bucket + self._transfer_coordinator = transfer_coordinator + self._time_utils = time_utils + if time_utils is None: + self._time_utils = TimeUtils() + self._bandwidth_limiting_enabled = True + self._request_token = RequestToken() + self._bytes_seen = 0 + self._bytes_threshold = bytes_threshold + + def enable_bandwidth_limiting(self): + """Enable bandwidth limiting on reads to the stream""" + self._bandwidth_limiting_enabled = True + + def disable_bandwidth_limiting(self): + """Disable bandwidth limiting on reads to the stream""" + self._bandwidth_limiting_enabled = False + + def read(self, amount): + """Read a specified amount + + Reads will only be throttled if bandwidth limiting is enabled. + """ + if not self._bandwidth_limiting_enabled: + return self._fileobj.read(amount) + + # We do not want to be calling consume on every read as the read + # amounts can be small causing the lock of the leaky bucket to + # introduce noticeable overhead. So instead we keep track of + # how many bytes we have seen and only call consume once we pass a + # certain threshold. 
+        self._bytes_seen += amount
+        if self._bytes_seen < self._bytes_threshold:
+            return self._fileobj.read(amount)
+
+        self._consume_through_leaky_bucket()
+        return self._fileobj.read(amount)
+
+    def _consume_through_leaky_bucket(self):
+        # NOTE: If the read amounts on the stream are high, it will result
+        # in large bursty behavior as there is not an interface for partial
+        # reads. However, given the reads on this abstraction are at most 256KB
+        # (via downloads), it reduces the burstiness to be small KB bursts at
+        # worst.
+        while not self._transfer_coordinator.exception:
+            try:
+                self._leaky_bucket.consume(
+                    self._bytes_seen, self._request_token)
+                self._bytes_seen = 0
+                return
+            except RequestExceededException as e:
+                self._time_utils.sleep(e.retry_time)
+        else:
+            raise self._transfer_coordinator.exception
+
+    def signal_transferring(self):
+        """Signal that data being read is being transferred to S3"""
+        self.enable_bandwidth_limiting()
+
+    def signal_not_transferring(self):
+        """Signal that data being read is not being transferred to S3"""
+        self.disable_bandwidth_limiting()
+
+    def seek(self, where):
+        self._fileobj.seek(where)
+
+    def tell(self):
+        return self._fileobj.tell()
+
+    def close(self):
+        if self._bandwidth_limiting_enabled and self._bytes_seen:
+            # This handles the case where the file is small enough to never
+            # trigger the threshold and thus is never subjected to the
+            # leaky bucket on read(). This specifically happens for small
+            # uploads. So to account for those bytes, have
+            # them go through the leaky bucket when the file gets closed.
+            self._consume_through_leaky_bucket()
+        self._fileobj.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.close()
+
+
+class LeakyBucket(object):
+    def __init__(self, max_rate, time_utils=None, rate_tracker=None,
+                 consumption_scheduler=None):
+        """A leaky bucket abstraction to limit bandwidth consumption
+
+        :type max_rate: int
+        :param max_rate: The maximum rate to allow. This rate is in terms of
+            bytes per second.
+
+        :type time_utils: TimeUtils
+        :param time_utils: The time utility to use for interacting with time
+
+        :type rate_tracker: BandwidthRateTracker
+        :param rate_tracker: Tracks bandwidth consumption
+
+        :type consumption_scheduler: ConsumptionScheduler
+        :param consumption_scheduler: Schedules consumption retries when
+            necessary
+        """
+        self._max_rate = float(max_rate)
+        self._time_utils = time_utils
+        if time_utils is None:
+            self._time_utils = TimeUtils()
+        self._lock = threading.Lock()
+        self._rate_tracker = rate_tracker
+        if rate_tracker is None:
+            self._rate_tracker = BandwidthRateTracker()
+        self._consumption_scheduler = consumption_scheduler
+        if consumption_scheduler is None:
+            self._consumption_scheduler = ConsumptionScheduler()
+
+    def consume(self, amt, request_token):
+        """Consume a requested amount
+
+        :type amt: int
+        :param amt: The number of bytes to request to consume
+
+        :type request_token: RequestToken
+        :param request_token: The token associated with the consumption
+            request that is used to identify the request. So if a
+            RequestExceededException is raised, the token should be used
+            in subsequent retry consume() requests.
+
+        :raises RequestExceededException: If the consumption amount would
+            exceed the maximum allocated bandwidth
+
+        :rtype: int
+        :returns: The amount consumed
+        """
+        with self._lock:
+            time_now = self._time_utils.time()
+            if self._consumption_scheduler.is_scheduled(request_token):
+                return self._release_requested_amt_for_scheduled_request(
+                    amt, request_token, time_now)
+            elif self._projected_to_exceed_max_rate(amt, time_now):
+                self._raise_request_exceeded_exception(
+                    amt, request_token, time_now)
+            else:
+                return self._release_requested_amt(amt, time_now)
+
+    def _projected_to_exceed_max_rate(self, amt, time_now):
+        projected_rate = self._rate_tracker.get_projected_rate(amt, time_now)
+        return projected_rate > self._max_rate
+
+    def _release_requested_amt_for_scheduled_request(self, amt, request_token,
+                                                     time_now):
+        self._consumption_scheduler.process_scheduled_consumption(
+            request_token)
+        return self._release_requested_amt(amt, time_now)
+
+    def _raise_request_exceeded_exception(self, amt, request_token, time_now):
+        allocated_time = amt/float(self._max_rate)
+        retry_time = self._consumption_scheduler.schedule_consumption(
+            amt, request_token, allocated_time)
+        raise RequestExceededException(
+            requested_amt=amt, retry_time=retry_time)
+
+    def _release_requested_amt(self, amt, time_now):
+        self._rate_tracker.record_consumption_rate(amt, time_now)
+        return amt
+
+
+class ConsumptionScheduler(object):
+    def __init__(self):
+        """Schedules when to consume a desired amount"""
+        self._tokens_to_scheduled_consumption = {}
+        self._total_wait = 0
+
+    def is_scheduled(self, token):
+        """Indicates if a consumption request has been scheduled
+
+        :type token: RequestToken
+        :param token: The token associated with the consumption
+            request that is used to identify the request.
+        """
+        return token in self._tokens_to_scheduled_consumption
+
+    def schedule_consumption(self, amt, token, time_to_consume):
+        """Schedules a wait time to be able to consume an amount
+
+        :type amt: int
+        :param amt: The number of bytes scheduled to be consumed
+
+        :type token: RequestToken
+        :param token: The token associated with the consumption
+            request that is used to identify the request.
+
+        :type time_to_consume: float
+        :param time_to_consume: The desired time it should take for that
+            specific request amount to be consumed, regardless of previously
+            scheduled consumption requests
+
+        :rtype: float
+        :returns: The amount of time to wait for the specific request before
+            actually consuming the specified amount.
+        """
+        self._total_wait += time_to_consume
+        self._tokens_to_scheduled_consumption[token] = {
+            'wait_duration': self._total_wait,
+            'time_to_consume': time_to_consume,
+        }
+        return self._total_wait
+
+    def process_scheduled_consumption(self, token):
+        """Processes a scheduled consumption request that has completed
+
+        :type token: RequestToken
+        :param token: The token associated with the consumption
+            request that is used to identify the request.
+        """
+        scheduled_retry = self._tokens_to_scheduled_consumption.pop(token)
+        self._total_wait = max(
+            self._total_wait - scheduled_retry['time_to_consume'], 0)
+
+
+class BandwidthRateTracker(object):
+    def __init__(self, alpha=0.8):
+        """Tracks the rate of bandwidth consumption
+
+        :type alpha: float
+        :param alpha: The constant to use in calculating the exponential moving
+            average of the bandwidth rate.
Specifically it is used in the + following calculation: + + current_rate = alpha * new_rate + (1 - alpha) * current_rate + + This value of this constant should be between 0 and 1. + """ + self._alpha = alpha + self._last_time = None + self._current_rate = None + + @property + def current_rate(self): + """The current transfer rate + + :rtype: float + :returns: The current tracked transfer rate + """ + if self._last_time is None: + return 0.0 + return self._current_rate + + def get_projected_rate(self, amt, time_at_consumption): + """Get the projected rate using a provided amount and time + + :type amt: int + :param amt: The proposed amount to consume + + :type time_at_consumption: float + :param time_at_consumption: The proposed time to consume at + + :rtype: float + :returns: The consumption rate if that amt and time were consumed + """ + if self._last_time is None: + return 0.0 + return self._calculate_exponential_moving_average_rate( + amt, time_at_consumption) + + def record_consumption_rate(self, amt, time_at_consumption): + """Record the consumption rate based off amount and time point + + :type amt: int + :param amt: The amount that got consumed + + :type time_at_consumption: float + :param time_at_consumption: The time at which the amount was consumed + """ + if self._last_time is None: + self._last_time = time_at_consumption + self._current_rate = 0.0 + return + self._current_rate = self._calculate_exponential_moving_average_rate( + amt, time_at_consumption) + self._last_time = time_at_consumption + + def _calculate_rate(self, amt, time_at_consumption): + return amt / (time_at_consumption - self._last_time) + + def _calculate_exponential_moving_average_rate(self, amt, + time_at_consumption): + new_rate = self._calculate_rate(amt, time_at_consumption) + return self._alpha * new_rate + (1 - self._alpha) * self._current_rate diff --git a/s3transfer/download.py b/s3transfer/download.py index b60d398c..9b1a49e5 100644 --- a/s3transfer/download.py +++ b/s3transfer/download.py @@ -317,7 +317,7 @@ def _get_download_output_manager_cls(self, transfer_future, osutil): fileobj, type(fileobj))) def _submit(self, client, config, osutil, request_executor, io_executor, - transfer_future): + transfer_future, bandwidth_limiter=None): """ :param client: The client associated with the transfer manager @@ -339,6 +339,10 @@ def _submit(self, client, config, osutil, request_executor, io_executor, :type transfer_future: s3transfer.futures.TransferFuture :param transfer_future: The transfer future associated with the transfer request that tasks are being submitted for + + :type bandwidth_limiter: s3transfer.bandwidth.BandwidthLimiter + :param bandwidth_limiter: The bandwidth limiter to use when + downloading streams """ if transfer_future.meta.size is None: # If a size was not provided figure out the size for the @@ -360,15 +364,16 @@ def _submit(self, client, config, osutil, request_executor, io_executor, if transfer_future.meta.size < config.multipart_threshold: self._submit_download_request( client, config, osutil, request_executor, io_executor, - download_output_manager, transfer_future) + download_output_manager, transfer_future, bandwidth_limiter) else: self._submit_ranged_download_request( client, config, osutil, request_executor, io_executor, - download_output_manager, transfer_future) + download_output_manager, transfer_future, bandwidth_limiter) def _submit_download_request(self, client, config, osutil, request_executor, io_executor, - download_output_manager, transfer_future): + 
download_output_manager, transfer_future, + bandwidth_limiter): call_args = transfer_future.meta.call_args # Get a handle to the file that will be used for writing downloaded @@ -400,6 +405,7 @@ def _submit_download_request(self, client, config, osutil, 'max_attempts': config.num_download_attempts, 'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, + 'bandwidth_limiter': bandwidth_limiter }, done_callbacks=[final_task] ), @@ -409,7 +415,8 @@ def _submit_download_request(self, client, config, osutil, def _submit_ranged_download_request(self, client, config, osutil, request_executor, io_executor, download_output_manager, - transfer_future): + transfer_future, + bandwidth_limiter): call_args = transfer_future.meta.call_args # Get the needed progress callbacks for the task @@ -461,6 +468,7 @@ def _submit_ranged_download_request(self, client, config, osutil, 'start_index': i * part_size, 'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, + 'bandwidth_limiter': bandwidth_limiter }, done_callbacks=[finalize_download_invoker.decrement] ), @@ -488,7 +496,7 @@ def _calculate_range_param(self, part_size, part_index, num_parts): class GetObjectTask(Task): def _main(self, client, bucket, key, fileobj, extra_args, callbacks, max_attempts, download_output_manager, io_chunksize, - start_index=0): + start_index=0, bandwidth_limiter=None): """Downloads an object and places content into io queue :param client: The client to use when calling GetObject @@ -504,6 +512,8 @@ def _main(self, client, bucket, key, fileobj, extra_args, callbacks, download stream and queue in the io queue. :param start_index: The location in the file to start writing the content of the key to. + :param bandwidth_limiter: The bandwidth limiter to use when throttling + the downloading of data in streams. 
""" last_exception = None for i in range(max_attempts): @@ -512,6 +522,10 @@ def _main(self, client, bucket, key, fileobj, extra_args, callbacks, Bucket=bucket, Key=key, **extra_args) streaming_body = StreamReaderProgress( response['Body'], callbacks) + if bandwidth_limiter: + streaming_body = \ + bandwidth_limiter.get_bandwith_limited_stream( + streaming_body, self._transfer_coordinator) current_index = start_index chunks = DownloadChunkIterator(streaming_body, io_chunksize) diff --git a/s3transfer/manager.py b/s3transfer/manager.py index 4816f171..63352ae0 100644 --- a/s3transfer/manager.py +++ b/s3transfer/manager.py @@ -17,8 +17,8 @@ from botocore.compat import six from s3transfer.utils import get_callbacks -from s3transfer.utils import disable_upload_callbacks -from s3transfer.utils import enable_upload_callbacks +from s3transfer.utils import signal_transferring +from s3transfer.utils import signal_not_transferring from s3transfer.utils import CallArgs from s3transfer.utils import OSUtils from s3transfer.utils import TaskSemaphore @@ -35,6 +35,8 @@ from s3transfer.upload import UploadSubmissionTask from s3transfer.copies import CopySubmissionTask from s3transfer.delete import DeleteSubmissionTask +from s3transfer.bandwidth import LeakyBucket +from s3transfer.bandwidth import BandwidthLimiter KB = 1024 MB = KB * KB @@ -53,7 +55,8 @@ def __init__(self, io_chunksize=256 * KB, num_download_attempts=5, max_in_memory_upload_chunks=10, - max_in_memory_download_chunks=10): + max_in_memory_download_chunks=10, + max_bandwidth=None): """Configurations for the transfer mangager :param multipart_threshold: The threshold for which multipart @@ -124,6 +127,9 @@ def __init__(self, max_in_memory_download_chunks * multipart_chunksize + :param max_bandwidth: The maximum bandwidth that will be consumed + in uploading and downloading file content. The value is in terms of + bytes per second. """ self.multipart_threshold = multipart_threshold self.multipart_chunksize = multipart_chunksize @@ -136,11 +142,12 @@ def __init__(self, self.num_download_attempts = num_download_attempts self.max_in_memory_upload_chunks = max_in_memory_upload_chunks self.max_in_memory_download_chunks = max_in_memory_download_chunks + self.max_bandwidth = max_bandwidth self._validate_attrs_are_nonzero() def _validate_attrs_are_nonzero(self): for attr, attr_val, in self.__dict__.items(): - if attr_val <= 0: + if attr_val is not None and attr_val <= 0: raise ValueError( 'Provided parameter %s of value %s must be greater than ' '0.' % (attr, attr_val)) @@ -247,6 +254,16 @@ def __init__(self, client, config=None, osutil=None, executor_cls=None): max_num_threads=1, executor_cls=executor_cls ) + + # The component responsible for limiting bandwidth usage if it + # is configured. 
+ self._bandwidth_limiter = None + if self._config.max_bandwidth is not None: + logger.debug( + 'Setting max_bandwidth to %s', self._config.max_bandwidth) + token_bucket = LeakyBucket(self._config.max_bandwidth) + self._bandwidth_limiter = BandwidthLimiter(token_bucket) + self._register_handlers() def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): @@ -284,7 +301,11 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): fileobj=fileobj, bucket=bucket, key=key, extra_args=extra_args, subscribers=subscribers ) - return self._submit_transfer(call_args, UploadSubmissionTask) + extra_main_kwargs = {} + if self._bandwidth_limiter: + extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter + return self._submit_transfer( + call_args, UploadSubmissionTask, extra_main_kwargs) def download(self, bucket, key, fileobj, extra_args=None, subscribers=None): @@ -320,8 +341,11 @@ def download(self, bucket, key, fileobj, extra_args=None, bucket=bucket, key=key, fileobj=fileobj, extra_args=extra_args, subscribers=subscribers ) - return self._submit_transfer(call_args, DownloadSubmissionTask, - {'io_executor': self._io_executor}) + extra_main_kwargs = {'io_executor': self._io_executor} + if self._bandwidth_limiter: + extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter + return self._submit_transfer( + call_args, DownloadSubmissionTask, extra_main_kwargs) def copy(self, copy_source, bucket, key, extra_args=None, subscribers=None, source_client=None): @@ -481,11 +505,11 @@ def _register_handlers(self): # Register handlers to enable/disable callbacks on uploads. event_name = 'request-created.s3' self._client.meta.events.register_first( - event_name, disable_upload_callbacks, - unique_id='s3upload-callback-disable') + event_name, signal_not_transferring, + unique_id='s3upload-not-transferring') self._client.meta.events.register_last( - event_name, enable_upload_callbacks, - unique_id='s3upload-callback-enable') + event_name, signal_transferring, + unique_id='s3upload-transferring') def __enter__(self): return self diff --git a/s3transfer/upload.py b/s3transfer/upload.py index d9a9ec5a..0c2feda2 100644 --- a/s3transfer/upload.py +++ b/s3transfer/upload.py @@ -115,9 +115,10 @@ class UploadInputManager(object): that may be accepted. All implementations must subclass and override public methods from this class. 
""" - def __init__(self, osutil, transfer_coordinator): + def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None): self._osutil = osutil self._transfer_coordinator = transfer_coordinator + self._bandwidth_limiter = bandwidth_limiter @classmethod def is_compatible(cls, upload_source): @@ -200,8 +201,12 @@ def yield_upload_part_bodies(self, transfer_future, chunksize): """ raise NotImplementedError('must implement yield_upload_part_bodies()') - def _wrap_with_interrupt_reader(self, fileobj): - return InterruptReader(fileobj, self._transfer_coordinator) + def _wrap_fileobj(self, fileobj): + fileobj = InterruptReader(fileobj, self._transfer_coordinator) + if self._bandwidth_limiter: + fileobj = self._bandwidth_limiter.get_bandwith_limited_stream( + fileobj, self._transfer_coordinator, enabled=False) + return fileobj def _get_progress_callbacks(self, transfer_future): callbacks = get_callbacks(transfer_future, 'progress') @@ -241,7 +246,7 @@ def get_put_object_body(self, transfer_future): # Wrap fileobj with interrupt reader that will quickly cancel # uploads if needed instead of having to wait for the socket # to completely read all of the data. - fileobj = self._wrap_with_interrupt_reader(fileobj) + fileobj = self._wrap_fileobj(fileobj) callbacks = self._get_progress_callbacks(transfer_future) close_callbacks = self._get_close_callbacks(callbacks) @@ -268,7 +273,7 @@ def yield_upload_part_bodies(self, transfer_future, chunksize): # Wrap fileobj with interrupt reader that will quickly cancel # uploads if needed instead of having to wait for the socket # to completely read all of the data. - fileobj = self._wrap_with_interrupt_reader(fileobj) + fileobj = self._wrap_fileobj(fileobj) # Wrap the file-like object into a ReadFileChunk to get progress. read_file_chunk = self._osutil.open_file_chunk_reader_from_fileobj( @@ -346,9 +351,9 @@ def _get_put_object_fileobj_with_full_size(self, transfer_future): class UploadNonSeekableInputManager(UploadInputManager): """Upload utility for a file-like object that cannot seek.""" - def __init__(self, osutil, transfer_coordinator): + def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None): super(UploadNonSeekableInputManager, self).__init__( - osutil, transfer_coordinator) + osutil, transfer_coordinator, bandwidth_limiter) self._initial_data = b'' @classmethod @@ -470,7 +475,7 @@ def _wrap_data(self, data, callbacks, close_callbacks): :return: Fully wrapped data. 
""" - fileobj = self._wrap_with_interrupt_reader(six.BytesIO(data)) + fileobj = self._wrap_fileobj(six.BytesIO(data)) return self._osutil.open_file_chunk_reader_from_fileobj( fileobj=fileobj, chunk_size=len(data), full_file_size=len(data), callbacks=callbacks, close_callbacks=close_callbacks) @@ -511,7 +516,7 @@ def _get_upload_input_manager_cls(self, transfer_future): fileobj, type(fileobj))) def _submit(self, client, config, osutil, request_executor, - transfer_future): + transfer_future, bandwidth_limiter=None): """ :param client: The client associated with the transfer manager @@ -531,7 +536,8 @@ def _submit(self, client, config, osutil, request_executor, transfer request that tasks are being submitted for """ upload_input_manager = self._get_upload_input_manager_cls( - transfer_future)(osutil, self._transfer_coordinator) + transfer_future)( + osutil, self._transfer_coordinator, bandwidth_limiter) # Determine the size if it was not provided if transfer_future.meta.size is None: diff --git a/s3transfer/utils.py b/s3transfer/utils.py index 65b34b88..fba56485 100644 --- a/s3transfer/utils.py +++ b/s3transfer/utils.py @@ -39,16 +39,16 @@ def random_file_extension(num_digits=8): return ''.join(random.choice(string.hexdigits) for _ in range(num_digits)) -def disable_upload_callbacks(request, operation_name, **kwargs): +def signal_not_transferring(request, operation_name, **kwargs): if operation_name in ['PutObject', 'UploadPart'] and \ - hasattr(request.body, 'disable_callback'): - request.body.disable_callback() + hasattr(request.body, 'signal_not_transferring'): + request.body.signal_not_transferring() -def enable_upload_callbacks(request, operation_name, **kwargs): +def signal_transferring(request, operation_name, **kwargs): if operation_name in ['PutObject', 'UploadPart'] and \ - hasattr(request.body, 'enable_callback'): - request.body.enable_callback() + hasattr(request.body, 'signal_transferring'): + request.body.signal_transferring() def calculate_range_parameter(part_size, part_index, num_parts, @@ -433,6 +433,16 @@ def read(self, amount=None): invoke_progress_callbacks(self._callbacks, len(data)) return data + def signal_transferring(self): + self.enable_callback() + if hasattr(self._fileobj, 'signal_transferring'): + self._fileobj.signal_transferring() + + def signal_not_transferring(self): + self.disable_callback() + if hasattr(self._fileobj, 'signal_not_transferring'): + self._fileobj.signal_not_transferring() + def enable_callback(self): self._callbacks_enabled = True diff --git a/tests/functional/test_download.py b/tests/functional/test_download.py index 4d999138..d15d901f 100644 --- a/tests/functional/test_download.py +++ b/tests/functional/test_download.py @@ -13,6 +13,7 @@ import copy import os import tempfile +import time import shutil import glob @@ -381,6 +382,34 @@ def test_download_empty_object(self): with open(self.filename, 'rb') as f: self.assertEqual(b'', f.read()) + def test_uses_bandwidth_limiter(self): + self.content = b'a' * 1024 * 1024 + self.stream = six.BytesIO(self.content) + self.config = TransferConfig( + max_request_concurrency=1, max_bandwidth=len(self.content)/2) + self._manager = TransferManager(self.client, self.config) + + self.add_head_object_response() + self.add_successful_get_object_responses() + + start = time.time() + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args) + future.result() + # This is just a smoke test to make sure that the limiter is + # being used and not necessary its exactness. 
So we set the maximum + # bandwidth to len(content)/2 per sec and make sure that it is + # noticeably slower. Ideally it will take more than two seconds, but + # given tracking at the beginning of transfers are not entirely + # accurate setting at the initial start of a transfer, we give us + # some flexibility by setting the expected time to half of the + # theoretical time to take. + self.assertGreaterEqual(time.time() - start, 1) + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + class TestRangedDownload(BaseDownloadTest): # TODO: If you want to add tests outside of this test class and still diff --git a/tests/functional/test_manager.py b/tests/functional/test_manager.py index 9aeb9cca..9d5702d3 100644 --- a/tests/functional/test_manager.py +++ b/tests/functional/test_manager.py @@ -27,18 +27,18 @@ class ArbitraryException(Exception): pass -class CallbackEnablingBody(RawIOBase): - """A mocked body with callback enabling/disabling""" +class SignalTransferringBody(RawIOBase): + """A mocked body with the ability to signal when transfers occur""" def __init__(self): - super(CallbackEnablingBody, self).__init__() - self.enable_callback_call_count = 0 - self.disable_callback_call_count = 0 + super(SignalTransferringBody, self).__init__() + self.signal_transferring_call_count = 0 + self.signal_not_transferring_call_count = 0 - def enable_callback(self): - self.enable_callback_call_count += 1 + def signal_transferring(self): + self.signal_transferring_call_count += 1 - def disable_callback(self): - self.disable_callback_call_count += 1 + def signal_not_transferring(self): + self.signal_not_transferring_call_count += 1 def seek(self, where): pass @@ -128,7 +128,7 @@ def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self): future.result() def test_enable_disable_callbacks_only_ever_registered_once(self): - body = CallbackEnablingBody() + body = SignalTransferringBody() request = create_request_object({ 'method': 'PUT', 'url': 'https://s3.amazonaws.com', @@ -145,10 +145,10 @@ def test_enable_disable_callbacks_only_ever_registered_once(self): # handlers registered once depite being used for two different # TransferManagers. self.assertEqual( - body.enable_callback_call_count, 1, + body.signal_transferring_call_count, 1, 'The enable_callback() should have only ever been registered once') self.assertEqual( - body.disable_callback_call_count, 1, + body.signal_not_transferring_call_count, 1, 'The disable_callback() should have only ever been registered ' 'once') diff --git a/tests/functional/test_upload.py b/tests/functional/test_upload.py index e880a8e3..06ee2b0c 100644 --- a/tests/functional/test_upload.py +++ b/tests/functional/test_upload.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
import os +import time import tempfile import shutil @@ -90,7 +91,16 @@ def collect_body(self, params, model, **kwargs): 'request-created.s3.%s' % model.name, request=request, operation_name=model.name ) - self.sent_bodies.append(params['Body'].read()) + self.sent_bodies.append(self._stream_body(params['Body'])) + + def _stream_body(self, body): + read_amt = 8 * 1024 + data = body.read(read_amt) + collected_body = data + while data: + data = body.read(read_amt) + collected_body += data + return collected_body @property def manager(self): @@ -231,6 +241,31 @@ def test_allowed_upload_params_are_valid(self): for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS: self.assertIn(allowed_upload_arg, op_model.input_shape.members) + def test_upload_with_bandwidth_limiter(self): + self.content = b'a' * 1024 * 1024 + with open(self.filename, 'wb') as f: + f.write(self.content) + self.config = TransferConfig( + max_request_concurrency=1, max_bandwidth=len(self.content)/2) + self._manager = TransferManager(self.client, self.config) + + self.add_put_object_response_with_default_expected_params() + start = time.time() + future = self.manager.upload(self.filename, self.bucket, self.key) + future.result() + # This is just a smoke test to make sure that the limiter is + # being used and not necessary its exactness. So we set the maximum + # bandwidth to len(content)/2 per sec and make sure that it is + # noticeably slower. Ideally it will take more than two seconds, but + # given tracking at the beginning of transfers are not entirely + # accurate setting at the initial start of a transfer, we give us + # some flexibility by setting the expected time to half of the + # theoretical time to take. + self.assertGreaterEqual(time.time() - start, 1) + + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + class TestMultipartUpload(BaseUploadTest): __test__ = True diff --git a/tests/unit/test_bandwidth.py b/tests/unit/test_bandwidth.py new file mode 100644 index 00000000..c6bb1986 --- /dev/null +++ b/tests/unit/test_bandwidth.py @@ -0,0 +1,458 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import os +import shutil +import tempfile + +import mock + +from tests import unittest +from s3transfer.bandwidth import RequestExceededException +from s3transfer.bandwidth import RequestToken +from s3transfer.bandwidth import TimeUtils +from s3transfer.bandwidth import BandwidthLimiter +from s3transfer.bandwidth import BandwidthLimitedStream +from s3transfer.bandwidth import LeakyBucket +from s3transfer.bandwidth import ConsumptionScheduler +from s3transfer.bandwidth import BandwidthRateTracker +from s3transfer.futures import TransferCoordinator + + +class FixedIncrementalTickTimeUtils(TimeUtils): + def __init__(self, seconds_per_tick=1.0): + self._count = 0 + self._seconds_per_tick = seconds_per_tick + + def time(self): + current_count = self._count + self._count += self._seconds_per_tick + return current_count + + +class TestTimeUtils(unittest.TestCase): + @mock.patch('time.time') + def test_time(self, mock_time): + mock_return_val = 1 + mock_time.return_value = mock_return_val + time_utils = TimeUtils() + self.assertEqual(time_utils.time(), mock_return_val) + + @mock.patch('time.sleep') + def test_sleep(self, mock_sleep): + time_utils = TimeUtils() + time_utils.sleep(1) + self.assertEqual( + mock_sleep.call_args_list, + [mock.call(1)] + ) + + +class BaseBandwidthLimitTest(unittest.TestCase): + def setUp(self): + self.leaky_bucket = mock.Mock(LeakyBucket) + self.time_utils = mock.Mock(TimeUtils) + self.tempdir = tempfile.mkdtemp() + self.content = b'a' * 1024 * 1024 + self.filename = os.path.join(self.tempdir, 'myfile') + with open(self.filename, 'wb') as f: + f.write(self.content) + self.coordinator = TransferCoordinator() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def assert_consume_calls(self, amts): + expected_consume_args = [ + mock.call(amt, mock.ANY) for amt in amts + ] + self.assertEqual( + self.leaky_bucket.consume.call_args_list, + expected_consume_args + ) + + +class TestBandwidthLimiter(BaseBandwidthLimitTest): + def setUp(self): + super(TestBandwidthLimiter, self).setUp() + self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket) + + def test_get_bandwidth_limited_stream(self): + with open(self.filename, 'rb') as f: + stream = self.bandwidth_limiter.get_bandwith_limited_stream( + f, self.coordinator) + self.assertIsInstance(stream, BandwidthLimitedStream) + self.assertEqual(stream.read(len(self.content)), self.content) + self.assert_consume_calls(amts=[len(self.content)]) + + def test_get_disabled_bandwidth_limited_stream(self): + with open(self.filename, 'rb') as f: + stream = self.bandwidth_limiter.get_bandwith_limited_stream( + f, self.coordinator, enabled=False) + self.assertIsInstance(stream, BandwidthLimitedStream) + self.assertEqual(stream.read(len(self.content)), self.content) + self.leaky_bucket.consume.assert_not_called() + + +class TestBandwidthLimitedStream(BaseBandwidthLimitTest): + def setUp(self): + super(TestBandwidthLimitedStream, self).setUp() + self.bytes_threshold = 1 + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def get_bandwidth_limited_stream(self, f): + return BandwidthLimitedStream( + f, self.leaky_bucket, self.coordinator, self.time_utils, + self.bytes_threshold) + + def assert_sleep_calls(self, amts): + expected_sleep_args_list = [ + mock.call(amt) for amt in amts + ] + self.assertEqual( + self.time_utils.sleep.call_args_list, + expected_sleep_args_list + ) + + def get_unique_consume_request_tokens(self): + return set( + call_args[0][1] for call_args in + self.leaky_bucket.consume.call_args_list + ) + + def 
test_read(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assert_consume_calls(amts=[len(self.content)]) + self.assert_sleep_calls(amts=[]) + + def test_retries_on_request_exceeded(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + retry_time = 1 + amt_requested = len(self.content) + self.leaky_bucket.consume.side_effect = [ + RequestExceededException(amt_requested, retry_time), + len(self.content) + ] + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assert_consume_calls(amts=[amt_requested, amt_requested]) + self.assert_sleep_calls(amts=[retry_time]) + + def test_with_transfer_coordinator_exception(self): + self.coordinator.set_exception(ValueError()) + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + with self.assertRaises(ValueError): + stream.read(len(self.content)) + + def test_read_when_bandwidth_limiting_disabled(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.disable_bandwidth_limiting() + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assertFalse(self.leaky_bucket.consume.called) + + def test_read_toggle_disable_enable_bandwidth_limiting(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.disable_bandwidth_limiting() + data = stream.read(1) + self.assertEqual(self.content[:1], data) + self.assert_consume_calls(amts=[]) + stream.enable_bandwidth_limiting() + data = stream.read(len(self.content) - 1) + self.assertEqual(self.content[1:], data) + self.assert_consume_calls(amts=[len(self.content) - 1]) + + def test_seek(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.seek(1) + self.assertEqual( + mock_fileobj.seek.call_args_list, + [mock.call(1)] + ) + + def test_tell(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.tell() + self.assertEqual( + mock_fileobj.tell.call_args_list, + [mock.call()] + ) + + def test_close(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.close() + self.assertEqual( + mock_fileobj.close.call_args_list, + [mock.call()] + ) + + def test_context_manager(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + with stream as stream_handle: + self.assertIs(stream_handle, stream) + self.assertEqual( + mock_fileobj.close.call_args_list, + [mock.call()] + ) + + def test_reuses_request_token(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + stream.read(1) + self.assertEqual(len(self.get_unique_consume_request_tokens()), 1) + + def test_request_tokens_unique_per_stream(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + self.assertEqual(len(self.get_unique_consume_request_tokens()), 2) + + def test_call_consume_after_reaching_threshold(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + 
self.assertEqual(stream.read(1), self.content[1:2]) + self.assert_consume_calls(amts=[2]) + + def test_resets_after_reaching_threshold(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(2), self.content[:2]) + self.assert_consume_calls(amts=[2]) + self.assertEqual(stream.read(1), self.content[2:3]) + self.assert_consume_calls(amts=[2]) + + def test_pending_bytes_seen_on_close(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + stream.close() + self.assert_consume_calls(amts=[1]) + + def test_no_bytes_remaining_on(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(2), self.content[:2]) + self.assert_consume_calls(amts=[2]) + stream.close() + # There should have been no more consume() calls made + # as all bytes have been accounted for in the previous + # consume() call. + self.assert_consume_calls(amts=[2]) + + def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + stream.disable_bandwidth_limiting() + stream.close() + self.assert_consume_calls(amts=[]) + + def test_signal_transferring(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.signal_not_transferring() + data = stream.read(1) + self.assertEqual(self.content[:1], data) + self.assert_consume_calls(amts=[]) + stream.signal_transferring() + data = stream.read(len(self.content) - 1) + self.assertEqual(self.content[1:], data) + self.assert_consume_calls(amts=[len(self.content) - 1]) + + +class TestLeakyBucket(unittest.TestCase): + def setUp(self): + self.max_rate = 1 + self.time_now = 1.0 + self.time_utils = mock.Mock(TimeUtils) + self.time_utils.time.return_value = self.time_now + self.scheduler = mock.Mock(ConsumptionScheduler) + self.scheduler.is_scheduled.return_value = False + self.rate_tracker = mock.Mock(BandwidthRateTracker) + self.leaky_bucket = LeakyBucket( + self.max_rate, self.time_utils, self.rate_tracker, + self.scheduler + ) + + def set_projected_rate(self, rate): + self.rate_tracker.get_projected_rate.return_value = rate + + def set_retry_time(self, retry_time): + self.scheduler.schedule_consumption.return_value = retry_time + + def assert_recorded_consumed_amt(self, expected_amt): + self.assertEqual( + self.rate_tracker.record_consumption_rate.call_args, + mock.call(expected_amt, self.time_utils.time.return_value)) + + def assert_was_scheduled(self, amt, token): + self.assertEqual( + self.scheduler.schedule_consumption.call_args, + mock.call(amt, token, amt/(self.max_rate)) + ) + + def assert_nothing_scheduled(self): + self.assertFalse(self.scheduler.schedule_consumption.called) + + def assert_processed_request_token(self, request_token): + self.assertEqual( + self.scheduler.process_scheduled_consumption.call_args, + mock.call(request_token) + ) + + def test_consume_under_max_rate(self): + amt = 1 + self.set_projected_rate(self.max_rate/2) + self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) + self.assert_recorded_consumed_amt(amt) + self.assert_nothing_scheduled() + + def 
test_consume_at_max_rate(self): + amt = 1 + self.set_projected_rate(self.max_rate) + self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) + self.assert_recorded_consumed_amt(amt) + self.assert_nothing_scheduled() + + def test_consume_over_max_rate(self): + amt = 1 + retry_time = 2.0 + self.set_projected_rate(self.max_rate + 1) + self.set_retry_time(retry_time) + request_token = RequestToken() + try: + self.leaky_bucket.consume(amt, request_token) + self.fail('A RequestExceededException should have been thrown') + except RequestExceededException as e: + self.assertEqual(e.requested_amt, amt) + self.assertEqual(e.retry_time, retry_time) + self.assert_was_scheduled(amt, request_token) + + def test_consume_with_scheduled_retry(self): + amt = 1 + self.set_projected_rate(self.max_rate + 1) + self.scheduler.is_scheduled.return_value = True + request_token = RequestToken() + self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt) + # Nothing new should have been scheduled but the request token + # should have been processed. + self.assert_nothing_scheduled() + self.assert_processed_request_token(request_token) + + +class TestConsumptionScheduler(unittest.TestCase): + def setUp(self): + self.scheduler = ConsumptionScheduler() + + def test_schedule_consumption(self): + token = RequestToken() + consume_time = 5 + actual_wait_time = self.scheduler.schedule_consumption( + 1, token, consume_time) + self.assertEqual(consume_time, actual_wait_time) + + def test_schedule_consumption_for_multiple_requests(self): + token = RequestToken() + consume_time = 5 + actual_wait_time = self.scheduler.schedule_consumption( + 1, token, consume_time) + self.assertEqual(consume_time, actual_wait_time) + + other_consume_time = 3 + other_token = RequestToken() + next_wait_time = self.scheduler.schedule_consumption( + 1, other_token, other_consume_time) + + # This wait time should be the previous time plus its desired + # wait time + self.assertEqual(next_wait_time, consume_time + other_consume_time) + + def test_is_scheduled(self): + token = RequestToken() + consume_time = 5 + self.scheduler.schedule_consumption(1, token, consume_time) + self.assertTrue(self.scheduler.is_scheduled(token)) + + def test_is_not_scheduled(self): + self.assertFalse(self.scheduler.is_scheduled(RequestToken())) + + def test_process_scheduled_consumption(self): + token = RequestToken() + consume_time = 5 + self.scheduler.schedule_consumption(1, token, consume_time) + self.scheduler.process_scheduled_consumption(token) + self.assertFalse(self.scheduler.is_scheduled(token)) + different_time = 7 + # The previous consume time should have no affect on the next wait tim + # as it has been completed. 
+ self.assertEqual( + self.scheduler.schedule_consumption(1, token, different_time), + different_time + ) + + +class TestBandwidthRateTracker(unittest.TestCase): + def setUp(self): + self.alpha = 0.8 + self.rate_tracker = BandwidthRateTracker(self.alpha) + + def test_current_rate_at_initilizations(self): + self.assertEqual(self.rate_tracker.current_rate, 0.0) + + def test_current_rate_after_one_recorded_point(self): + self.rate_tracker.record_consumption_rate(1, 1) + # There is no last time point to do a diff against so return a + # current rate of 0.0 + self.assertEqual(self.rate_tracker.current_rate, 0.0) + + def test_current_rate(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.rate_tracker.record_consumption_rate(1, 2) + self.rate_tracker.record_consumption_rate(1, 3) + self.assertEqual(self.rate_tracker.current_rate, 0.96) + + def test_get_projected_rate_at_initilizations(self): + self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0) + + def test_get_projected_rate(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.rate_tracker.record_consumption_rate(1, 2) + projected_rate = self.rate_tracker.get_projected_rate(1, 3) + self.assertEqual(projected_rate, 0.96) + self.rate_tracker.record_consumption_rate(1, 3) + self.assertEqual(self.rate_tracker.current_rate, projected_rate) diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index bc759d65..b2d9fe55 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -26,6 +26,7 @@ from s3transfer.compat import six from s3transfer.compat import SOCKET_ERROR from s3transfer.exceptions import RetriesExceededError +from s3transfer.bandwidth import BandwidthLimiter from s3transfer.download import DownloadFilenameOutputManager from s3transfer.download import DownloadSpecialFilenameOutputManager from s3transfer.download import DownloadSeekableOutputManager @@ -611,6 +612,22 @@ def test_start_index(self): self.stubber.assert_no_pending_responses() self.assert_io_writes([(5, self.content)]) + def test_uses_bandwidth_limiter(self): + bandwidth_limiter = mock.Mock(BandwidthLimiter) + + self.stubber.add_response( + 'get_object', service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key} + ) + task = self.get_download_task(bandwidth_limiter=bandwidth_limiter) + task() + + self.stubber.assert_no_pending_responses() + self.assertEqual( + bandwidth_limiter.get_bandwith_limited_stream.call_args_list, + [mock.call(mock.ANY, self.transfer_coordinator)] + ) + def test_retries_succeeds(self): self.stubber.add_response( 'get_object', service_response={ diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d431b1d7..f76b2699 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -18,6 +18,8 @@ import time import io +import mock + from tests import unittest from tests import RecordingSubscriber from tests import NonSeekableWriter @@ -511,6 +513,55 @@ def test_close_callbacks_when_context_handler_is_used(self): chunk.read(1) self.assertEqual(self.num_close_callback_calls, 1) + def test_signal_transferring(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, + callbacks=[self.callback]) + chunk.signal_not_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, []) + chunk.signal_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, [1]) + + def test_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + 
underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_transferring() + self.assertTrue(underlying_stream.signal_transferring.called) + + def test_no_call_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock(io.RawIOBase) + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + try: + chunk.signal_transferring() + except AttributeError: + self.fail( + 'The stream should not have tried to call signal_transferring ' + 'to the underlying stream.' + ) + + def test_signal_not_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_not_transferring() + self.assertTrue(underlying_stream.signal_not_transferring.called) + + def test_no_call_signal_not_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock(io.RawIOBase) + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + try: + chunk.signal_not_transferring() + except AttributeError: + self.fail( + 'The stream should not have tried to call ' + 'signal_not_transferring to the underlying stream.' + ) + class TestStreamReaderProgress(BaseUtilsTest): def test_proxies_to_wrapped_stream(self): From 7409490403d3e5a53fb389f91ca4a0195c095af3 Mon Sep 17 00:00:00 2001 From: kyleknap Date: Tue, 28 Nov 2017 22:15:43 -0800 Subject: [PATCH 2/3] Switch references to leaky bucket --- s3transfer/bandwidth.py | 12 ++++++------ s3transfer/manager.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/s3transfer/bandwidth.py b/s3transfer/bandwidth.py index cb693e28..3ac9b340 100644 --- a/s3transfer/bandwidth.py +++ b/s3transfer/bandwidth.py @@ -61,16 +61,16 @@ def sleep(self, value): class BandwidthLimiter(object): - def __init__(self, token_bucket, time_utils=None): + def __init__(self, leaky_bucket, time_utils=None): """Limits bandwidth for shared S3 transfers :type leaky_bucket: LeakyBucket - :param token_bucket: The leaky bucket to use limit bandwidth + :param leaky_bucket: The leaky bucket to use limit bandwidth :type time_utils: TimeUtils :param time_utils: Time utility to use for interacting with time. 
""" - self._token_bucket = token_bucket + self._leaky_bucket = leaky_bucket self._time_utils = time_utils if time_utils is None: self._time_utils = TimeUtils() @@ -90,7 +90,7 @@ def get_bandwith_limited_stream(self, fileobj, transfer_coordinator, :param enabled: Whether bandwidth limiting should be enabled to start """ stream = BandwidthLimitedStream( - fileobj, self._token_bucket, transfer_coordinator, + fileobj, self._leaky_bucket, transfer_coordinator, self._time_utils) if not enabled: stream.disable_bandwidth_limiting() @@ -105,8 +105,8 @@ def __init__(self, fileobj, leaky_bucket, transfer_coordinator, :type fileobj: file-like object :param fileobj: The file like object to wrap - :type token_bucket: LeakyBucket - :param token_bucket: The leaky bucket to use to throttle reads on + :type leaky_bucket: LeakyBucket + :param leaky_bucket: The leaky bucket to use to throttle reads on the stream :type transfer_coordinator: s3transfer.futures.TransferCoordinator diff --git a/s3transfer/manager.py b/s3transfer/manager.py index 63352ae0..fb7a9cbb 100644 --- a/s3transfer/manager.py +++ b/s3transfer/manager.py @@ -261,8 +261,8 @@ def __init__(self, client, config=None, osutil=None, executor_cls=None): if self._config.max_bandwidth is not None: logger.debug( 'Setting max_bandwidth to %s', self._config.max_bandwidth) - token_bucket = LeakyBucket(self._config.max_bandwidth) - self._bandwidth_limiter = BandwidthLimiter(token_bucket) + leaky_bucket = LeakyBucket(self._config.max_bandwidth) + self._bandwidth_limiter = BandwidthLimiter(leaky_bucket) self._register_handlers() From 94b402dc2981dc73e77162967924c1243dcde09c Mon Sep 17 00:00:00 2001 From: kyleknap Date: Wed, 29 Nov 2017 05:41:47 -0800 Subject: [PATCH 3/3] Add case for calculating rate of zero time delta This was not seen in real life transfers, but for the functional tests, the time deltas between reads were infitesimally small causing the rate to be calculate with a time of 0 and throw a ZeroDivisionError --- s3transfer/bandwidth.py | 9 ++++++++- tests/unit/test_bandwidth.py | 7 +++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/s3transfer/bandwidth.py b/s3transfer/bandwidth.py index 3ac9b340..8b3f6f50 100644 --- a/s3transfer/bandwidth.py +++ b/s3transfer/bandwidth.py @@ -401,7 +401,14 @@ def record_consumption_rate(self, amt, time_at_consumption): self._last_time = time_at_consumption def _calculate_rate(self, amt, time_at_consumption): - return amt / (time_at_consumption - self._last_time) + time_delta = time_at_consumption - self._last_time + if time_delta <= 0: + # While it is really unlikley to see this in an actual transfer, + # we do not want to be returning back a negative rate or try to + # divide the amount by zero. So instead return back an infinite + # rate as the time delta is infinitesimally small. + return float('inf') + return amt / (time_delta) def _calculate_exponential_moving_average_rate(self, amt, time_at_consumption): diff --git a/tests/unit/test_bandwidth.py b/tests/unit/test_bandwidth.py index c6bb1986..4f35a7a8 100644 --- a/tests/unit/test_bandwidth.py +++ b/tests/unit/test_bandwidth.py @@ -456,3 +456,10 @@ def test_get_projected_rate(self): self.assertEqual(projected_rate, 0.96) self.rate_tracker.record_consumption_rate(1, 3) self.assertEqual(self.rate_tracker.current_rate, projected_rate) + + def test_get_projected_rate_for_same_timestamp(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.assertEqual( + self.rate_tracker.get_projected_rate(1, 1), + float('inf') + )