diff --git a/.changes/next-release/enhancement-maxbandwidth-87115.json b/.changes/next-release/enhancement-maxbandwidth-87115.json
new file mode 100644
index 00000000..f730339e
--- /dev/null
+++ b/.changes/next-release/enhancement-maxbandwidth-87115.json
@@ -0,0 +1,5 @@
+{
+  "type": "enhancement",
+  "category": "``max_bandwidth``",
+  "description": "Add ability to set maximum bandwidth consumption for streaming of S3 uploads and downloads"
+}
diff --git a/s3transfer/bandwidth.py b/s3transfer/bandwidth.py
new file mode 100644
index 00000000..8b3f6f50
--- /dev/null
+++ b/s3transfer/bandwidth.py
@@ -0,0 +1,416 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import time
+import threading
+
+
+class RequestExceededException(Exception):
+    def __init__(self, requested_amt, retry_time):
+        """Error when the requested amount exceeds what is allowed
+
+        The request that raised this error should be retried after waiting
+        the time specified by ``retry_time``.
+
+        :type requested_amt: int
+        :param requested_amt: The originally requested byte amount
+
+        :type retry_time: float
+        :param retry_time: The length of time to wait before retrying for the
+            requested amount
+        """
+        self.requested_amt = requested_amt
+        self.retry_time = retry_time
+        msg = (
+            'Request amount %s exceeded the amount available. Retry in %s' % (
+                requested_amt, retry_time)
+        )
+        super(RequestExceededException, self).__init__(msg)
+
+
+class RequestToken(object):
+    """A token to pass as an identifier when consuming from the LeakyBucket"""
+    pass
+
+
+class TimeUtils(object):
+    def time(self):
+        """Get the current time back
+
+        :rtype: float
+        :returns: The current time in seconds
+        """
+        return time.time()
+
+    def sleep(self, value):
+        """Sleep for a designated time
+
+        :type value: float
+        :param value: The time to sleep for in seconds
+        """
+        return time.sleep(value)
+
+
+class BandwidthLimiter(object):
+    def __init__(self, leaky_bucket, time_utils=None):
+        """Limits bandwidth for shared S3 transfers
+
+        :type leaky_bucket: LeakyBucket
+        :param leaky_bucket: The leaky bucket to use to limit bandwidth
+
+        :type time_utils: TimeUtils
+        :param time_utils: Time utility to use for interacting with time.
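+            Defaults to an instance of ``TimeUtils`` when not provided.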
+ """ + self._leaky_bucket = leaky_bucket + self._time_utils = time_utils + if time_utils is None: + self._time_utils = TimeUtils() + + def get_bandwith_limited_stream(self, fileobj, transfer_coordinator, + enabled=True): + """Wraps a fileobj in a bandwidth limited stream wrapper + + :type fileobj: file-like obj + :param fileobj: The file-like obj to wrap + + :type transfer_coordinator: s3transfer.futures.TransferCoordinator + param transfer_coordinator: The coordinator for the general transfer + that the wrapped stream is a part of + + :type enabled: boolean + :param enabled: Whether bandwidth limiting should be enabled to start + """ + stream = BandwidthLimitedStream( + fileobj, self._leaky_bucket, transfer_coordinator, + self._time_utils) + if not enabled: + stream.disable_bandwidth_limiting() + return stream + + +class BandwidthLimitedStream(object): + def __init__(self, fileobj, leaky_bucket, transfer_coordinator, + time_utils=None, bytes_threshold=256 * 1024): + """Limits bandwidth for reads on a wrapped stream + + :type fileobj: file-like object + :param fileobj: The file like object to wrap + + :type leaky_bucket: LeakyBucket + :param leaky_bucket: The leaky bucket to use to throttle reads on + the stream + + :type transfer_coordinator: s3transfer.futures.TransferCoordinator + param transfer_coordinator: The coordinator for the general transfer + that the wrapped stream is a part of + + :type time_utils: TimeUtils + :param time_utils: The time utility to use for interacting with time + """ + self._fileobj = fileobj + self._leaky_bucket = leaky_bucket + self._transfer_coordinator = transfer_coordinator + self._time_utils = time_utils + if time_utils is None: + self._time_utils = TimeUtils() + self._bandwidth_limiting_enabled = True + self._request_token = RequestToken() + self._bytes_seen = 0 + self._bytes_threshold = bytes_threshold + + def enable_bandwidth_limiting(self): + """Enable bandwidth limiting on reads to the stream""" + self._bandwidth_limiting_enabled = True + + def disable_bandwidth_limiting(self): + """Disable bandwidth limiting on reads to the stream""" + self._bandwidth_limiting_enabled = False + + def read(self, amount): + """Read a specified amount + + Reads will only be throttled if bandwidth limiting is enabled. + """ + if not self._bandwidth_limiting_enabled: + return self._fileobj.read(amount) + + # We do not want to be calling consume on every read as the read + # amounts can be small causing the lock of the leaky bucket to + # introduce noticeable overhead. So instead we keep track of + # how many bytes we have seen and only call consume once we pass a + # certain threshold. + self._bytes_seen += amount + if self._bytes_seen < self._bytes_threshold: + return self._fileobj.read(amount) + + self._consume_through_leaky_bucket() + return self._fileobj.read(amount) + + def _consume_through_leaky_bucket(self): + # NOTE: If the read amonut on the stream are high, it will result + # in large bursty behavior as there is not an interface for partial + # reads. However given the read's on this abstraction are at most 256KB + # (via downloads), it reduces the burstiness to be small KB bursts at + # worst. 
+        while not self._transfer_coordinator.exception:
+            try:
+                self._leaky_bucket.consume(
+                    self._bytes_seen, self._request_token)
+                self._bytes_seen = 0
+                return
+            except RequestExceededException as e:
+                self._time_utils.sleep(e.retry_time)
+        else:
+            raise self._transfer_coordinator.exception
+
+    def signal_transferring(self):
+        """Signal that data being read is being transferred to S3"""
+        self.enable_bandwidth_limiting()
+
+    def signal_not_transferring(self):
+        """Signal that data being read is not being transferred to S3"""
+        self.disable_bandwidth_limiting()
+
+    def seek(self, where):
+        self._fileobj.seek(where)
+
+    def tell(self):
+        return self._fileobj.tell()
+
+    def close(self):
+        if self._bandwidth_limiting_enabled and self._bytes_seen:
+            # This handles the case where the file is small enough to never
+            # trigger the threshold and thus is never subjected to the
+            # leaky bucket on read(). This specifically happens for small
+            # uploads. So to account for those bytes, have them go through
+            # the leaky bucket when the file gets closed.
+            self._consume_through_leaky_bucket()
+        self._fileobj.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.close()
+
+
+class LeakyBucket(object):
+    def __init__(self, max_rate, time_utils=None, rate_tracker=None,
+                 consumption_scheduler=None):
+        """A leaky bucket abstraction to limit bandwidth consumption
+
+        :type max_rate: int
+        :param max_rate: The maximum rate to allow. This rate is in terms of
+            bytes per second.
+
+        :type time_utils: TimeUtils
+        :param time_utils: The time utility to use for interacting with time
+
+        :type rate_tracker: BandwidthRateTracker
+        :param rate_tracker: Tracks bandwidth consumption
+
+        :type consumption_scheduler: ConsumptionScheduler
+        :param consumption_scheduler: Schedules consumption retries when
+            necessary
+        """
+        self._max_rate = float(max_rate)
+        self._time_utils = time_utils
+        if time_utils is None:
+            self._time_utils = TimeUtils()
+        self._lock = threading.Lock()
+        self._rate_tracker = rate_tracker
+        if rate_tracker is None:
+            self._rate_tracker = BandwidthRateTracker()
+        self._consumption_scheduler = consumption_scheduler
+        if consumption_scheduler is None:
+            self._consumption_scheduler = ConsumptionScheduler()
+
+    def consume(self, amt, request_token):
+        """Consume a requested amount
+
+        :type amt: int
+        :param amt: The amount of bytes to request to consume
+
+        :type request_token: RequestToken
+        :param request_token: The token associated to the consumption
+            request that is used to identify the request. So if a
+            RequestExceededException is raised, the token should be used
+            in subsequent retry consume() requests.
+
+        :raises RequestExceededException: If the consumption amount would
+            exceed the maximum allocated bandwidth
+
+        :rtype: int
+        :returns: The amount consumed
+        """
+        with self._lock:
+            time_now = self._time_utils.time()
+            if self._consumption_scheduler.is_scheduled(request_token):
+                return self._release_requested_amt_for_scheduled_request(
+                    amt, request_token, time_now)
+            elif self._projected_to_exceed_max_rate(amt, time_now):
+                self._raise_request_exceeded_exception(
+                    amt, request_token, time_now)
+            else:
+                return self._release_requested_amt(amt, time_now)
+
+    def _projected_to_exceed_max_rate(self, amt, time_now):
+        projected_rate = self._rate_tracker.get_projected_rate(amt, time_now)
+        return projected_rate > self._max_rate
+
+    def _release_requested_amt_for_scheduled_request(self, amt, request_token,
+                                                     time_now):
+        self._consumption_scheduler.process_scheduled_consumption(
+            request_token)
+        return self._release_requested_amt(amt, time_now)
+
+    def _raise_request_exceeded_exception(self, amt, request_token, time_now):
+        allocated_time = amt/float(self._max_rate)
+        retry_time = self._consumption_scheduler.schedule_consumption(
+            amt, request_token, allocated_time)
+        raise RequestExceededException(
+            requested_amt=amt, retry_time=retry_time)
+
+    def _release_requested_amt(self, amt, time_now):
+        self._rate_tracker.record_consumption_rate(amt, time_now)
+        return amt
+
+
+class ConsumptionScheduler(object):
+    def __init__(self):
+        """Schedules when to consume a desired amount"""
+        self._tokens_to_scheduled_consumption = {}
+        self._total_wait = 0
+
+    def is_scheduled(self, token):
+        """Indicates if a consumption request has been scheduled
+
+        :type token: RequestToken
+        :param token: The token associated to the consumption
+            request that is used to identify the request.
+        """
+        return token in self._tokens_to_scheduled_consumption
+
+    def schedule_consumption(self, amt, token, time_to_consume):
+        """Schedules a wait time to be able to consume an amount
+
+        :type amt: int
+        :param amt: The amount of bytes scheduled to be consumed
+
+        :type token: RequestToken
+        :param token: The token associated to the consumption
+            request that is used to identify the request.
+
+        :type time_to_consume: float
+        :param time_to_consume: The desired time it should take for that
+            specific request amount to be consumed, regardless of previously
+            scheduled consumption requests
+
+        :rtype: float
+        :returns: The amount of time to wait for the specific request before
+            actually consuming the specified amount.
+        """
+        self._total_wait += time_to_consume
+        self._tokens_to_scheduled_consumption[token] = {
+            'wait_duration': self._total_wait,
+            'time_to_consume': time_to_consume,
+        }
+        return self._total_wait
+
+    def process_scheduled_consumption(self, token):
+        """Processes a scheduled consumption request that has completed
+
+        :type token: RequestToken
+        :param token: The token associated to the consumption
+            request that is used to identify the request.
+        """
+        scheduled_retry = self._tokens_to_scheduled_consumption.pop(token)
+        self._total_wait = max(
+            self._total_wait - scheduled_retry['time_to_consume'], 0)
+
+
+class BandwidthRateTracker(object):
+    def __init__(self, alpha=0.8):
+        """Tracks the rate of bandwidth consumption
+
+        :type alpha: float
+        :param alpha: The constant to use in calculating the exponential
+            moving average of the bandwidth rate.
+            Specifically it is used in the following calculation:
+
+            current_rate = alpha * new_rate + (1 - alpha) * current_rate
+
+            The value of this constant should be between 0 and 1.
+        """
+        self._alpha = alpha
+        self._last_time = None
+        self._current_rate = None
+
+    @property
+    def current_rate(self):
+        """The current transfer rate
+
+        :rtype: float
+        :returns: The current tracked transfer rate
+        """
+        if self._last_time is None:
+            return 0.0
+        return self._current_rate
+
+    def get_projected_rate(self, amt, time_at_consumption):
+        """Get the projected rate using a provided amount and time
+
+        :type amt: int
+        :param amt: The proposed amount to consume
+
+        :type time_at_consumption: float
+        :param time_at_consumption: The proposed time to consume at
+
+        :rtype: float
+        :returns: The consumption rate if that amt and time were consumed
+        """
+        if self._last_time is None:
+            return 0.0
+        return self._calculate_exponential_moving_average_rate(
+            amt, time_at_consumption)
+
+    def record_consumption_rate(self, amt, time_at_consumption):
+        """Record the consumption rate based on an amount and time point
+
+        :type amt: int
+        :param amt: The amount that got consumed
+
+        :type time_at_consumption: float
+        :param time_at_consumption: The time at which the amount was consumed
+        """
+        if self._last_time is None:
+            self._last_time = time_at_consumption
+            self._current_rate = 0.0
+            return
+        self._current_rate = self._calculate_exponential_moving_average_rate(
+            amt, time_at_consumption)
+        self._last_time = time_at_consumption
+
+    def _calculate_rate(self, amt, time_at_consumption):
+        time_delta = time_at_consumption - self._last_time
+        if time_delta <= 0:
+            # While it is really unlikely to see this in an actual transfer,
+            # we do not want to be returning back a negative rate or try to
+            # divide the amount by zero. So instead return back an infinite
+            # rate as the time delta is infinitesimally small.
+ return float('inf') + return amt / (time_delta) + + def _calculate_exponential_moving_average_rate(self, amt, + time_at_consumption): + new_rate = self._calculate_rate(amt, time_at_consumption) + return self._alpha * new_rate + (1 - self._alpha) * self._current_rate diff --git a/s3transfer/download.py b/s3transfer/download.py index b60d398c..9b1a49e5 100644 --- a/s3transfer/download.py +++ b/s3transfer/download.py @@ -317,7 +317,7 @@ def _get_download_output_manager_cls(self, transfer_future, osutil): fileobj, type(fileobj))) def _submit(self, client, config, osutil, request_executor, io_executor, - transfer_future): + transfer_future, bandwidth_limiter=None): """ :param client: The client associated with the transfer manager @@ -339,6 +339,10 @@ def _submit(self, client, config, osutil, request_executor, io_executor, :type transfer_future: s3transfer.futures.TransferFuture :param transfer_future: The transfer future associated with the transfer request that tasks are being submitted for + + :type bandwidth_limiter: s3transfer.bandwidth.BandwidthLimiter + :param bandwidth_limiter: The bandwidth limiter to use when + downloading streams """ if transfer_future.meta.size is None: # If a size was not provided figure out the size for the @@ -360,15 +364,16 @@ def _submit(self, client, config, osutil, request_executor, io_executor, if transfer_future.meta.size < config.multipart_threshold: self._submit_download_request( client, config, osutil, request_executor, io_executor, - download_output_manager, transfer_future) + download_output_manager, transfer_future, bandwidth_limiter) else: self._submit_ranged_download_request( client, config, osutil, request_executor, io_executor, - download_output_manager, transfer_future) + download_output_manager, transfer_future, bandwidth_limiter) def _submit_download_request(self, client, config, osutil, request_executor, io_executor, - download_output_manager, transfer_future): + download_output_manager, transfer_future, + bandwidth_limiter): call_args = transfer_future.meta.call_args # Get a handle to the file that will be used for writing downloaded @@ -400,6 +405,7 @@ def _submit_download_request(self, client, config, osutil, 'max_attempts': config.num_download_attempts, 'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, + 'bandwidth_limiter': bandwidth_limiter }, done_callbacks=[final_task] ), @@ -409,7 +415,8 @@ def _submit_download_request(self, client, config, osutil, def _submit_ranged_download_request(self, client, config, osutil, request_executor, io_executor, download_output_manager, - transfer_future): + transfer_future, + bandwidth_limiter): call_args = transfer_future.meta.call_args # Get the needed progress callbacks for the task @@ -461,6 +468,7 @@ def _submit_ranged_download_request(self, client, config, osutil, 'start_index': i * part_size, 'download_output_manager': download_output_manager, 'io_chunksize': config.io_chunksize, + 'bandwidth_limiter': bandwidth_limiter }, done_callbacks=[finalize_download_invoker.decrement] ), @@ -488,7 +496,7 @@ def _calculate_range_param(self, part_size, part_index, num_parts): class GetObjectTask(Task): def _main(self, client, bucket, key, fileobj, extra_args, callbacks, max_attempts, download_output_manager, io_chunksize, - start_index=0): + start_index=0, bandwidth_limiter=None): """Downloads an object and places content into io queue :param client: The client to use when calling GetObject @@ -504,6 +512,8 @@ def _main(self, client, bucket, key, fileobj, extra_args, 
callbacks, download stream and queue in the io queue. :param start_index: The location in the file to start writing the content of the key to. + :param bandwidth_limiter: The bandwidth limiter to use when throttling + the downloading of data in streams. """ last_exception = None for i in range(max_attempts): @@ -512,6 +522,10 @@ def _main(self, client, bucket, key, fileobj, extra_args, callbacks, Bucket=bucket, Key=key, **extra_args) streaming_body = StreamReaderProgress( response['Body'], callbacks) + if bandwidth_limiter: + streaming_body = \ + bandwidth_limiter.get_bandwith_limited_stream( + streaming_body, self._transfer_coordinator) current_index = start_index chunks = DownloadChunkIterator(streaming_body, io_chunksize) diff --git a/s3transfer/manager.py b/s3transfer/manager.py index 4816f171..fb7a9cbb 100644 --- a/s3transfer/manager.py +++ b/s3transfer/manager.py @@ -17,8 +17,8 @@ from botocore.compat import six from s3transfer.utils import get_callbacks -from s3transfer.utils import disable_upload_callbacks -from s3transfer.utils import enable_upload_callbacks +from s3transfer.utils import signal_transferring +from s3transfer.utils import signal_not_transferring from s3transfer.utils import CallArgs from s3transfer.utils import OSUtils from s3transfer.utils import TaskSemaphore @@ -35,6 +35,8 @@ from s3transfer.upload import UploadSubmissionTask from s3transfer.copies import CopySubmissionTask from s3transfer.delete import DeleteSubmissionTask +from s3transfer.bandwidth import LeakyBucket +from s3transfer.bandwidth import BandwidthLimiter KB = 1024 MB = KB * KB @@ -53,7 +55,8 @@ def __init__(self, io_chunksize=256 * KB, num_download_attempts=5, max_in_memory_upload_chunks=10, - max_in_memory_download_chunks=10): + max_in_memory_download_chunks=10, + max_bandwidth=None): """Configurations for the transfer mangager :param multipart_threshold: The threshold for which multipart @@ -124,6 +127,9 @@ def __init__(self, max_in_memory_download_chunks * multipart_chunksize + :param max_bandwidth: The maximum bandwidth that will be consumed + in uploading and downloading file content. The value is in terms of + bytes per second. """ self.multipart_threshold = multipart_threshold self.multipart_chunksize = multipart_chunksize @@ -136,11 +142,12 @@ def __init__(self, self.num_download_attempts = num_download_attempts self.max_in_memory_upload_chunks = max_in_memory_upload_chunks self.max_in_memory_download_chunks = max_in_memory_download_chunks + self.max_bandwidth = max_bandwidth self._validate_attrs_are_nonzero() def _validate_attrs_are_nonzero(self): for attr, attr_val, in self.__dict__.items(): - if attr_val <= 0: + if attr_val is not None and attr_val <= 0: raise ValueError( 'Provided parameter %s of value %s must be greater than ' '0.' % (attr, attr_val)) @@ -247,6 +254,16 @@ def __init__(self, client, config=None, osutil=None, executor_cls=None): max_num_threads=1, executor_cls=executor_cls ) + + # The component responsible for limiting bandwidth usage if it + # is configured. 
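+        # For example (illustrative only), a caller could cap all transfers
+        # made through this manager at roughly 1 MB/s with:
+        #
+        #     config = TransferConfig(max_bandwidth=1024 * 1024)
+        #     manager = TransferManager(client, config)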
+ self._bandwidth_limiter = None + if self._config.max_bandwidth is not None: + logger.debug( + 'Setting max_bandwidth to %s', self._config.max_bandwidth) + leaky_bucket = LeakyBucket(self._config.max_bandwidth) + self._bandwidth_limiter = BandwidthLimiter(leaky_bucket) + self._register_handlers() def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): @@ -284,7 +301,11 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): fileobj=fileobj, bucket=bucket, key=key, extra_args=extra_args, subscribers=subscribers ) - return self._submit_transfer(call_args, UploadSubmissionTask) + extra_main_kwargs = {} + if self._bandwidth_limiter: + extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter + return self._submit_transfer( + call_args, UploadSubmissionTask, extra_main_kwargs) def download(self, bucket, key, fileobj, extra_args=None, subscribers=None): @@ -320,8 +341,11 @@ def download(self, bucket, key, fileobj, extra_args=None, bucket=bucket, key=key, fileobj=fileobj, extra_args=extra_args, subscribers=subscribers ) - return self._submit_transfer(call_args, DownloadSubmissionTask, - {'io_executor': self._io_executor}) + extra_main_kwargs = {'io_executor': self._io_executor} + if self._bandwidth_limiter: + extra_main_kwargs['bandwidth_limiter'] = self._bandwidth_limiter + return self._submit_transfer( + call_args, DownloadSubmissionTask, extra_main_kwargs) def copy(self, copy_source, bucket, key, extra_args=None, subscribers=None, source_client=None): @@ -481,11 +505,11 @@ def _register_handlers(self): # Register handlers to enable/disable callbacks on uploads. event_name = 'request-created.s3' self._client.meta.events.register_first( - event_name, disable_upload_callbacks, - unique_id='s3upload-callback-disable') + event_name, signal_not_transferring, + unique_id='s3upload-not-transferring') self._client.meta.events.register_last( - event_name, enable_upload_callbacks, - unique_id='s3upload-callback-enable') + event_name, signal_transferring, + unique_id='s3upload-transferring') def __enter__(self): return self diff --git a/s3transfer/upload.py b/s3transfer/upload.py index d9a9ec5a..0c2feda2 100644 --- a/s3transfer/upload.py +++ b/s3transfer/upload.py @@ -115,9 +115,10 @@ class UploadInputManager(object): that may be accepted. All implementations must subclass and override public methods from this class. 
""" - def __init__(self, osutil, transfer_coordinator): + def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None): self._osutil = osutil self._transfer_coordinator = transfer_coordinator + self._bandwidth_limiter = bandwidth_limiter @classmethod def is_compatible(cls, upload_source): @@ -200,8 +201,12 @@ def yield_upload_part_bodies(self, transfer_future, chunksize): """ raise NotImplementedError('must implement yield_upload_part_bodies()') - def _wrap_with_interrupt_reader(self, fileobj): - return InterruptReader(fileobj, self._transfer_coordinator) + def _wrap_fileobj(self, fileobj): + fileobj = InterruptReader(fileobj, self._transfer_coordinator) + if self._bandwidth_limiter: + fileobj = self._bandwidth_limiter.get_bandwith_limited_stream( + fileobj, self._transfer_coordinator, enabled=False) + return fileobj def _get_progress_callbacks(self, transfer_future): callbacks = get_callbacks(transfer_future, 'progress') @@ -241,7 +246,7 @@ def get_put_object_body(self, transfer_future): # Wrap fileobj with interrupt reader that will quickly cancel # uploads if needed instead of having to wait for the socket # to completely read all of the data. - fileobj = self._wrap_with_interrupt_reader(fileobj) + fileobj = self._wrap_fileobj(fileobj) callbacks = self._get_progress_callbacks(transfer_future) close_callbacks = self._get_close_callbacks(callbacks) @@ -268,7 +273,7 @@ def yield_upload_part_bodies(self, transfer_future, chunksize): # Wrap fileobj with interrupt reader that will quickly cancel # uploads if needed instead of having to wait for the socket # to completely read all of the data. - fileobj = self._wrap_with_interrupt_reader(fileobj) + fileobj = self._wrap_fileobj(fileobj) # Wrap the file-like object into a ReadFileChunk to get progress. read_file_chunk = self._osutil.open_file_chunk_reader_from_fileobj( @@ -346,9 +351,9 @@ def _get_put_object_fileobj_with_full_size(self, transfer_future): class UploadNonSeekableInputManager(UploadInputManager): """Upload utility for a file-like object that cannot seek.""" - def __init__(self, osutil, transfer_coordinator): + def __init__(self, osutil, transfer_coordinator, bandwidth_limiter=None): super(UploadNonSeekableInputManager, self).__init__( - osutil, transfer_coordinator) + osutil, transfer_coordinator, bandwidth_limiter) self._initial_data = b'' @classmethod @@ -470,7 +475,7 @@ def _wrap_data(self, data, callbacks, close_callbacks): :return: Fully wrapped data. 
""" - fileobj = self._wrap_with_interrupt_reader(six.BytesIO(data)) + fileobj = self._wrap_fileobj(six.BytesIO(data)) return self._osutil.open_file_chunk_reader_from_fileobj( fileobj=fileobj, chunk_size=len(data), full_file_size=len(data), callbacks=callbacks, close_callbacks=close_callbacks) @@ -511,7 +516,7 @@ def _get_upload_input_manager_cls(self, transfer_future): fileobj, type(fileobj))) def _submit(self, client, config, osutil, request_executor, - transfer_future): + transfer_future, bandwidth_limiter=None): """ :param client: The client associated with the transfer manager @@ -531,7 +536,8 @@ def _submit(self, client, config, osutil, request_executor, transfer request that tasks are being submitted for """ upload_input_manager = self._get_upload_input_manager_cls( - transfer_future)(osutil, self._transfer_coordinator) + transfer_future)( + osutil, self._transfer_coordinator, bandwidth_limiter) # Determine the size if it was not provided if transfer_future.meta.size is None: diff --git a/s3transfer/utils.py b/s3transfer/utils.py index 65b34b88..fba56485 100644 --- a/s3transfer/utils.py +++ b/s3transfer/utils.py @@ -39,16 +39,16 @@ def random_file_extension(num_digits=8): return ''.join(random.choice(string.hexdigits) for _ in range(num_digits)) -def disable_upload_callbacks(request, operation_name, **kwargs): +def signal_not_transferring(request, operation_name, **kwargs): if operation_name in ['PutObject', 'UploadPart'] and \ - hasattr(request.body, 'disable_callback'): - request.body.disable_callback() + hasattr(request.body, 'signal_not_transferring'): + request.body.signal_not_transferring() -def enable_upload_callbacks(request, operation_name, **kwargs): +def signal_transferring(request, operation_name, **kwargs): if operation_name in ['PutObject', 'UploadPart'] and \ - hasattr(request.body, 'enable_callback'): - request.body.enable_callback() + hasattr(request.body, 'signal_transferring'): + request.body.signal_transferring() def calculate_range_parameter(part_size, part_index, num_parts, @@ -433,6 +433,16 @@ def read(self, amount=None): invoke_progress_callbacks(self._callbacks, len(data)) return data + def signal_transferring(self): + self.enable_callback() + if hasattr(self._fileobj, 'signal_transferring'): + self._fileobj.signal_transferring() + + def signal_not_transferring(self): + self.disable_callback() + if hasattr(self._fileobj, 'signal_not_transferring'): + self._fileobj.signal_not_transferring() + def enable_callback(self): self._callbacks_enabled = True diff --git a/tests/functional/test_download.py b/tests/functional/test_download.py index 4d999138..d15d901f 100644 --- a/tests/functional/test_download.py +++ b/tests/functional/test_download.py @@ -13,6 +13,7 @@ import copy import os import tempfile +import time import shutil import glob @@ -381,6 +382,34 @@ def test_download_empty_object(self): with open(self.filename, 'rb') as f: self.assertEqual(b'', f.read()) + def test_uses_bandwidth_limiter(self): + self.content = b'a' * 1024 * 1024 + self.stream = six.BytesIO(self.content) + self.config = TransferConfig( + max_request_concurrency=1, max_bandwidth=len(self.content)/2) + self._manager = TransferManager(self.client, self.config) + + self.add_head_object_response() + self.add_successful_get_object_responses() + + start = time.time() + future = self.manager.download( + self.bucket, self.key, self.filename, self.extra_args) + future.result() + # This is just a smoke test to make sure that the limiter is + # being used and not necessary its exactness. 
So we set the maximum + # bandwidth to len(content)/2 per sec and make sure that it is + # noticeably slower. Ideally it will take more than two seconds, but + # given tracking at the beginning of transfers are not entirely + # accurate setting at the initial start of a transfer, we give us + # some flexibility by setting the expected time to half of the + # theoretical time to take. + self.assertGreaterEqual(time.time() - start, 1) + + # Ensure that the contents are correct + with open(self.filename, 'rb') as f: + self.assertEqual(self.content, f.read()) + class TestRangedDownload(BaseDownloadTest): # TODO: If you want to add tests outside of this test class and still diff --git a/tests/functional/test_manager.py b/tests/functional/test_manager.py index 9aeb9cca..9d5702d3 100644 --- a/tests/functional/test_manager.py +++ b/tests/functional/test_manager.py @@ -27,18 +27,18 @@ class ArbitraryException(Exception): pass -class CallbackEnablingBody(RawIOBase): - """A mocked body with callback enabling/disabling""" +class SignalTransferringBody(RawIOBase): + """A mocked body with the ability to signal when transfers occur""" def __init__(self): - super(CallbackEnablingBody, self).__init__() - self.enable_callback_call_count = 0 - self.disable_callback_call_count = 0 + super(SignalTransferringBody, self).__init__() + self.signal_transferring_call_count = 0 + self.signal_not_transferring_call_count = 0 - def enable_callback(self): - self.enable_callback_call_count += 1 + def signal_transferring(self): + self.signal_transferring_call_count += 1 - def disable_callback(self): - self.disable_callback_call_count += 1 + def signal_not_transferring(self): + self.signal_not_transferring_call_count += 1 def seek(self, where): pass @@ -128,7 +128,7 @@ def test_cntrl_c_in_context_manager_cancels_incomplete_transfers(self): future.result() def test_enable_disable_callbacks_only_ever_registered_once(self): - body = CallbackEnablingBody() + body = SignalTransferringBody() request = create_request_object({ 'method': 'PUT', 'url': 'https://s3.amazonaws.com', @@ -145,10 +145,10 @@ def test_enable_disable_callbacks_only_ever_registered_once(self): # handlers registered once depite being used for two different # TransferManagers. self.assertEqual( - body.enable_callback_call_count, 1, + body.signal_transferring_call_count, 1, 'The enable_callback() should have only ever been registered once') self.assertEqual( - body.disable_callback_call_count, 1, + body.signal_not_transferring_call_count, 1, 'The disable_callback() should have only ever been registered ' 'once') diff --git a/tests/functional/test_upload.py b/tests/functional/test_upload.py index e880a8e3..06ee2b0c 100644 --- a/tests/functional/test_upload.py +++ b/tests/functional/test_upload.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
import os +import time import tempfile import shutil @@ -90,7 +91,16 @@ def collect_body(self, params, model, **kwargs): 'request-created.s3.%s' % model.name, request=request, operation_name=model.name ) - self.sent_bodies.append(params['Body'].read()) + self.sent_bodies.append(self._stream_body(params['Body'])) + + def _stream_body(self, body): + read_amt = 8 * 1024 + data = body.read(read_amt) + collected_body = data + while data: + data = body.read(read_amt) + collected_body += data + return collected_body @property def manager(self): @@ -231,6 +241,31 @@ def test_allowed_upload_params_are_valid(self): for allowed_upload_arg in self._manager.ALLOWED_UPLOAD_ARGS: self.assertIn(allowed_upload_arg, op_model.input_shape.members) + def test_upload_with_bandwidth_limiter(self): + self.content = b'a' * 1024 * 1024 + with open(self.filename, 'wb') as f: + f.write(self.content) + self.config = TransferConfig( + max_request_concurrency=1, max_bandwidth=len(self.content)/2) + self._manager = TransferManager(self.client, self.config) + + self.add_put_object_response_with_default_expected_params() + start = time.time() + future = self.manager.upload(self.filename, self.bucket, self.key) + future.result() + # This is just a smoke test to make sure that the limiter is + # being used and not necessary its exactness. So we set the maximum + # bandwidth to len(content)/2 per sec and make sure that it is + # noticeably slower. Ideally it will take more than two seconds, but + # given tracking at the beginning of transfers are not entirely + # accurate setting at the initial start of a transfer, we give us + # some flexibility by setting the expected time to half of the + # theoretical time to take. + self.assertGreaterEqual(time.time() - start, 1) + + self.assert_expected_client_calls_were_correct() + self.assert_put_object_body_was_correct() + class TestMultipartUpload(BaseUploadTest): __test__ = True diff --git a/tests/unit/test_bandwidth.py b/tests/unit/test_bandwidth.py new file mode 100644 index 00000000..4f35a7a8 --- /dev/null +++ b/tests/unit/test_bandwidth.py @@ -0,0 +1,465 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
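+# A rough sketch of how the pieces exercised below compose (illustrative
+# only, not an exhaustive description of the API), where ``fileobj`` is any
+# readable stream and ``coordinator`` is a TransferCoordinator:
+#
+#     bucket = LeakyBucket(max_rate=1024 * 1024)  # cap at ~1 MB/s
+#     limiter = BandwidthLimiter(bucket)
+#     stream = limiter.get_bandwith_limited_stream(fileobj, coordinator)
+#     stream.read(256 * 1024)  # reads are throttled through the bucket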
+import os +import shutil +import tempfile + +import mock + +from tests import unittest +from s3transfer.bandwidth import RequestExceededException +from s3transfer.bandwidth import RequestToken +from s3transfer.bandwidth import TimeUtils +from s3transfer.bandwidth import BandwidthLimiter +from s3transfer.bandwidth import BandwidthLimitedStream +from s3transfer.bandwidth import LeakyBucket +from s3transfer.bandwidth import ConsumptionScheduler +from s3transfer.bandwidth import BandwidthRateTracker +from s3transfer.futures import TransferCoordinator + + +class FixedIncrementalTickTimeUtils(TimeUtils): + def __init__(self, seconds_per_tick=1.0): + self._count = 0 + self._seconds_per_tick = seconds_per_tick + + def time(self): + current_count = self._count + self._count += self._seconds_per_tick + return current_count + + +class TestTimeUtils(unittest.TestCase): + @mock.patch('time.time') + def test_time(self, mock_time): + mock_return_val = 1 + mock_time.return_value = mock_return_val + time_utils = TimeUtils() + self.assertEqual(time_utils.time(), mock_return_val) + + @mock.patch('time.sleep') + def test_sleep(self, mock_sleep): + time_utils = TimeUtils() + time_utils.sleep(1) + self.assertEqual( + mock_sleep.call_args_list, + [mock.call(1)] + ) + + +class BaseBandwidthLimitTest(unittest.TestCase): + def setUp(self): + self.leaky_bucket = mock.Mock(LeakyBucket) + self.time_utils = mock.Mock(TimeUtils) + self.tempdir = tempfile.mkdtemp() + self.content = b'a' * 1024 * 1024 + self.filename = os.path.join(self.tempdir, 'myfile') + with open(self.filename, 'wb') as f: + f.write(self.content) + self.coordinator = TransferCoordinator() + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def assert_consume_calls(self, amts): + expected_consume_args = [ + mock.call(amt, mock.ANY) for amt in amts + ] + self.assertEqual( + self.leaky_bucket.consume.call_args_list, + expected_consume_args + ) + + +class TestBandwidthLimiter(BaseBandwidthLimitTest): + def setUp(self): + super(TestBandwidthLimiter, self).setUp() + self.bandwidth_limiter = BandwidthLimiter(self.leaky_bucket) + + def test_get_bandwidth_limited_stream(self): + with open(self.filename, 'rb') as f: + stream = self.bandwidth_limiter.get_bandwith_limited_stream( + f, self.coordinator) + self.assertIsInstance(stream, BandwidthLimitedStream) + self.assertEqual(stream.read(len(self.content)), self.content) + self.assert_consume_calls(amts=[len(self.content)]) + + def test_get_disabled_bandwidth_limited_stream(self): + with open(self.filename, 'rb') as f: + stream = self.bandwidth_limiter.get_bandwith_limited_stream( + f, self.coordinator, enabled=False) + self.assertIsInstance(stream, BandwidthLimitedStream) + self.assertEqual(stream.read(len(self.content)), self.content) + self.leaky_bucket.consume.assert_not_called() + + +class TestBandwidthLimitedStream(BaseBandwidthLimitTest): + def setUp(self): + super(TestBandwidthLimitedStream, self).setUp() + self.bytes_threshold = 1 + + def tearDown(self): + shutil.rmtree(self.tempdir) + + def get_bandwidth_limited_stream(self, f): + return BandwidthLimitedStream( + f, self.leaky_bucket, self.coordinator, self.time_utils, + self.bytes_threshold) + + def assert_sleep_calls(self, amts): + expected_sleep_args_list = [ + mock.call(amt) for amt in amts + ] + self.assertEqual( + self.time_utils.sleep.call_args_list, + expected_sleep_args_list + ) + + def get_unique_consume_request_tokens(self): + return set( + call_args[0][1] for call_args in + self.leaky_bucket.consume.call_args_list + ) + + def 
test_read(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assert_consume_calls(amts=[len(self.content)]) + self.assert_sleep_calls(amts=[]) + + def test_retries_on_request_exceeded(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + retry_time = 1 + amt_requested = len(self.content) + self.leaky_bucket.consume.side_effect = [ + RequestExceededException(amt_requested, retry_time), + len(self.content) + ] + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assert_consume_calls(amts=[amt_requested, amt_requested]) + self.assert_sleep_calls(amts=[retry_time]) + + def test_with_transfer_coordinator_exception(self): + self.coordinator.set_exception(ValueError()) + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + with self.assertRaises(ValueError): + stream.read(len(self.content)) + + def test_read_when_bandwidth_limiting_disabled(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.disable_bandwidth_limiting() + data = stream.read(len(self.content)) + self.assertEqual(self.content, data) + self.assertFalse(self.leaky_bucket.consume.called) + + def test_read_toggle_disable_enable_bandwidth_limiting(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.disable_bandwidth_limiting() + data = stream.read(1) + self.assertEqual(self.content[:1], data) + self.assert_consume_calls(amts=[]) + stream.enable_bandwidth_limiting() + data = stream.read(len(self.content) - 1) + self.assertEqual(self.content[1:], data) + self.assert_consume_calls(amts=[len(self.content) - 1]) + + def test_seek(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.seek(1) + self.assertEqual( + mock_fileobj.seek.call_args_list, + [mock.call(1)] + ) + + def test_tell(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.tell() + self.assertEqual( + mock_fileobj.tell.call_args_list, + [mock.call()] + ) + + def test_close(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + stream.close() + self.assertEqual( + mock_fileobj.close.call_args_list, + [mock.call()] + ) + + def test_context_manager(self): + mock_fileobj = mock.Mock() + stream = self.get_bandwidth_limited_stream(mock_fileobj) + with stream as stream_handle: + self.assertIs(stream_handle, stream) + self.assertEqual( + mock_fileobj.close.call_args_list, + [mock.call()] + ) + + def test_reuses_request_token(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + stream.read(1) + self.assertEqual(len(self.get_unique_consume_request_tokens()), 1) + + def test_request_tokens_unique_per_stream(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.read(1) + self.assertEqual(len(self.get_unique_consume_request_tokens()), 2) + + def test_call_consume_after_reaching_threshold(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + 
self.assertEqual(stream.read(1), self.content[1:2]) + self.assert_consume_calls(amts=[2]) + + def test_resets_after_reaching_threshold(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(2), self.content[:2]) + self.assert_consume_calls(amts=[2]) + self.assertEqual(stream.read(1), self.content[2:3]) + self.assert_consume_calls(amts=[2]) + + def test_pending_bytes_seen_on_close(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + stream.close() + self.assert_consume_calls(amts=[1]) + + def test_no_bytes_remaining_on(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(2), self.content[:2]) + self.assert_consume_calls(amts=[2]) + stream.close() + # There should have been no more consume() calls made + # as all bytes have been accounted for in the previous + # consume() call. + self.assert_consume_calls(amts=[2]) + + def test_disable_bandwidth_limiting_with_pending_bytes_seen_on_close(self): + self.bytes_threshold = 2 + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + self.assertEqual(stream.read(1), self.content[:1]) + self.assert_consume_calls(amts=[]) + stream.disable_bandwidth_limiting() + stream.close() + self.assert_consume_calls(amts=[]) + + def test_signal_transferring(self): + with open(self.filename, 'rb') as f: + stream = self.get_bandwidth_limited_stream(f) + stream.signal_not_transferring() + data = stream.read(1) + self.assertEqual(self.content[:1], data) + self.assert_consume_calls(amts=[]) + stream.signal_transferring() + data = stream.read(len(self.content) - 1) + self.assertEqual(self.content[1:], data) + self.assert_consume_calls(amts=[len(self.content) - 1]) + + +class TestLeakyBucket(unittest.TestCase): + def setUp(self): + self.max_rate = 1 + self.time_now = 1.0 + self.time_utils = mock.Mock(TimeUtils) + self.time_utils.time.return_value = self.time_now + self.scheduler = mock.Mock(ConsumptionScheduler) + self.scheduler.is_scheduled.return_value = False + self.rate_tracker = mock.Mock(BandwidthRateTracker) + self.leaky_bucket = LeakyBucket( + self.max_rate, self.time_utils, self.rate_tracker, + self.scheduler + ) + + def set_projected_rate(self, rate): + self.rate_tracker.get_projected_rate.return_value = rate + + def set_retry_time(self, retry_time): + self.scheduler.schedule_consumption.return_value = retry_time + + def assert_recorded_consumed_amt(self, expected_amt): + self.assertEqual( + self.rate_tracker.record_consumption_rate.call_args, + mock.call(expected_amt, self.time_utils.time.return_value)) + + def assert_was_scheduled(self, amt, token): + self.assertEqual( + self.scheduler.schedule_consumption.call_args, + mock.call(amt, token, amt/(self.max_rate)) + ) + + def assert_nothing_scheduled(self): + self.assertFalse(self.scheduler.schedule_consumption.called) + + def assert_processed_request_token(self, request_token): + self.assertEqual( + self.scheduler.process_scheduled_consumption.call_args, + mock.call(request_token) + ) + + def test_consume_under_max_rate(self): + amt = 1 + self.set_projected_rate(self.max_rate/2) + self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) + self.assert_recorded_consumed_amt(amt) + self.assert_nothing_scheduled() + + def 
test_consume_at_max_rate(self): + amt = 1 + self.set_projected_rate(self.max_rate) + self.assertEqual(self.leaky_bucket.consume(amt, RequestToken()), amt) + self.assert_recorded_consumed_amt(amt) + self.assert_nothing_scheduled() + + def test_consume_over_max_rate(self): + amt = 1 + retry_time = 2.0 + self.set_projected_rate(self.max_rate + 1) + self.set_retry_time(retry_time) + request_token = RequestToken() + try: + self.leaky_bucket.consume(amt, request_token) + self.fail('A RequestExceededException should have been thrown') + except RequestExceededException as e: + self.assertEqual(e.requested_amt, amt) + self.assertEqual(e.retry_time, retry_time) + self.assert_was_scheduled(amt, request_token) + + def test_consume_with_scheduled_retry(self): + amt = 1 + self.set_projected_rate(self.max_rate + 1) + self.scheduler.is_scheduled.return_value = True + request_token = RequestToken() + self.assertEqual(self.leaky_bucket.consume(amt, request_token), amt) + # Nothing new should have been scheduled but the request token + # should have been processed. + self.assert_nothing_scheduled() + self.assert_processed_request_token(request_token) + + +class TestConsumptionScheduler(unittest.TestCase): + def setUp(self): + self.scheduler = ConsumptionScheduler() + + def test_schedule_consumption(self): + token = RequestToken() + consume_time = 5 + actual_wait_time = self.scheduler.schedule_consumption( + 1, token, consume_time) + self.assertEqual(consume_time, actual_wait_time) + + def test_schedule_consumption_for_multiple_requests(self): + token = RequestToken() + consume_time = 5 + actual_wait_time = self.scheduler.schedule_consumption( + 1, token, consume_time) + self.assertEqual(consume_time, actual_wait_time) + + other_consume_time = 3 + other_token = RequestToken() + next_wait_time = self.scheduler.schedule_consumption( + 1, other_token, other_consume_time) + + # This wait time should be the previous time plus its desired + # wait time + self.assertEqual(next_wait_time, consume_time + other_consume_time) + + def test_is_scheduled(self): + token = RequestToken() + consume_time = 5 + self.scheduler.schedule_consumption(1, token, consume_time) + self.assertTrue(self.scheduler.is_scheduled(token)) + + def test_is_not_scheduled(self): + self.assertFalse(self.scheduler.is_scheduled(RequestToken())) + + def test_process_scheduled_consumption(self): + token = RequestToken() + consume_time = 5 + self.scheduler.schedule_consumption(1, token, consume_time) + self.scheduler.process_scheduled_consumption(token) + self.assertFalse(self.scheduler.is_scheduled(token)) + different_time = 7 + # The previous consume time should have no affect on the next wait tim + # as it has been completed. 
+ self.assertEqual( + self.scheduler.schedule_consumption(1, token, different_time), + different_time + ) + + +class TestBandwidthRateTracker(unittest.TestCase): + def setUp(self): + self.alpha = 0.8 + self.rate_tracker = BandwidthRateTracker(self.alpha) + + def test_current_rate_at_initilizations(self): + self.assertEqual(self.rate_tracker.current_rate, 0.0) + + def test_current_rate_after_one_recorded_point(self): + self.rate_tracker.record_consumption_rate(1, 1) + # There is no last time point to do a diff against so return a + # current rate of 0.0 + self.assertEqual(self.rate_tracker.current_rate, 0.0) + + def test_current_rate(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.rate_tracker.record_consumption_rate(1, 2) + self.rate_tracker.record_consumption_rate(1, 3) + self.assertEqual(self.rate_tracker.current_rate, 0.96) + + def test_get_projected_rate_at_initilizations(self): + self.assertEqual(self.rate_tracker.get_projected_rate(1, 1), 0.0) + + def test_get_projected_rate(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.rate_tracker.record_consumption_rate(1, 2) + projected_rate = self.rate_tracker.get_projected_rate(1, 3) + self.assertEqual(projected_rate, 0.96) + self.rate_tracker.record_consumption_rate(1, 3) + self.assertEqual(self.rate_tracker.current_rate, projected_rate) + + def test_get_projected_rate_for_same_timestamp(self): + self.rate_tracker.record_consumption_rate(1, 1) + self.assertEqual( + self.rate_tracker.get_projected_rate(1, 1), + float('inf') + ) diff --git a/tests/unit/test_download.py b/tests/unit/test_download.py index bc759d65..b2d9fe55 100644 --- a/tests/unit/test_download.py +++ b/tests/unit/test_download.py @@ -26,6 +26,7 @@ from s3transfer.compat import six from s3transfer.compat import SOCKET_ERROR from s3transfer.exceptions import RetriesExceededError +from s3transfer.bandwidth import BandwidthLimiter from s3transfer.download import DownloadFilenameOutputManager from s3transfer.download import DownloadSpecialFilenameOutputManager from s3transfer.download import DownloadSeekableOutputManager @@ -611,6 +612,22 @@ def test_start_index(self): self.stubber.assert_no_pending_responses() self.assert_io_writes([(5, self.content)]) + def test_uses_bandwidth_limiter(self): + bandwidth_limiter = mock.Mock(BandwidthLimiter) + + self.stubber.add_response( + 'get_object', service_response={'Body': self.stream}, + expected_params={'Bucket': self.bucket, 'Key': self.key} + ) + task = self.get_download_task(bandwidth_limiter=bandwidth_limiter) + task() + + self.stubber.assert_no_pending_responses() + self.assertEqual( + bandwidth_limiter.get_bandwith_limited_stream.call_args_list, + [mock.call(mock.ANY, self.transfer_coordinator)] + ) + def test_retries_succeeds(self): self.stubber.add_response( 'get_object', service_response={ diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d431b1d7..f76b2699 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -18,6 +18,8 @@ import time import io +import mock + from tests import unittest from tests import RecordingSubscriber from tests import NonSeekableWriter @@ -511,6 +513,55 @@ def test_close_callbacks_when_context_handler_is_used(self): chunk.read(1) self.assertEqual(self.num_close_callback_calls, 1) + def test_signal_transferring(self): + chunk = ReadFileChunk.from_filename( + self.filename, start_byte=0, chunk_size=3, + callbacks=[self.callback]) + chunk.signal_not_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, []) + 
chunk.signal_transferring() + chunk.read(1) + self.assertEqual(self.amounts_seen, [1]) + + def test_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_transferring() + self.assertTrue(underlying_stream.signal_transferring.called) + + def test_no_call_signal_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock(io.RawIOBase) + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + try: + chunk.signal_transferring() + except AttributeError: + self.fail( + 'The stream should not have tried to call signal_transferring ' + 'to the underlying stream.' + ) + + def test_signal_not_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock() + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + chunk.signal_not_transferring() + self.assertTrue(underlying_stream.signal_not_transferring.called) + + def test_no_call_signal_not_transferring_to_underlying_fileobj(self): + underlying_stream = mock.Mock(io.RawIOBase) + underlying_stream.tell.return_value = 0 + chunk = ReadFileChunk(underlying_stream, 3, 3) + try: + chunk.signal_not_transferring() + except AttributeError: + self.fail( + 'The stream should not have tried to call ' + 'signal_not_transferring to the underlying stream.' + ) + class TestStreamReaderProgress(BaseUtilsTest): def test_proxies_to_wrapped_stream(self):