From e3fea835bf29b9c6f8960ce28504691908d29fdb Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Sun, 18 Dec 2022 10:41:56 -0500 Subject: [PATCH 01/12] fix: ignore seek requests to the current position --- smart_open/s3.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index b2111637..5cf13ed6 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -663,9 +663,10 @@ def seek(self, offset, whence=constants.WHENCE_START): whence = constants.WHENCE_START offset += self._current_pos - self._current_pos = self._raw_reader.seek(offset, whence) + if not (whence == constants.WHENCE_START and offset == self._current_pos): + self._current_pos = self._raw_reader.seek(offset, whence) + self._buffer.empty() - self._buffer.empty() self._eof = self._current_pos == self._raw_reader._content_length return self._current_pos From c5f6954a970e96e20e4de31a4d3fc126cb51350b Mon Sep 17 00:00:00 2001 From: Rusty Conover Date: Thu, 26 Jan 2023 23:18:08 -0500 Subject: [PATCH 02/12] fix: adjust test to match new seek behavior --- smart_open/tests/test_s3_version.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/smart_open/tests/test_s3_version.py b/smart_open/tests/test_s3_version.py index 6f9584d0..94cd034d 100644 --- a/smart_open/tests/test_s3_version.py +++ b/smart_open/tests/test_s3_version.py @@ -70,8 +70,10 @@ def test_good_id(self): def test_bad_id(self): """Does passing an invalid version_id exception into the s3 submodule get handled correctly?""" params = {'version_id': 'bad-version-does-not-exist'} - with self.assertRaises(IOError): - open(self.url, 'rb', transport_params=params) + with open(self.url, 'rb', transport_params=params) as fin: + with self.assertRaises(IOError): + fin.read() + def test_bad_mode(self): """Do we correctly handle non-None version when writing?""" From 06016889212f21d8fda632e86c82adcc54a7cce7 Mon Sep 17 00:00:00 2001 From: Brian Beck Date: Wed, 6 Sep 2023 10:31:28 -0600 Subject: [PATCH 03/12] run seek if it is the first time --- smart_open/s3.py | 280 +++++++++++++++------------- smart_open/tests/test_s3_version.py | 60 +++--- 2 files changed, 184 insertions(+), 156 deletions(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index 5cf13ed6..0a5846c7 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -31,26 +31,26 @@ DEFAULT_MIN_PART_SIZE = 50 * 1024**2 """Default minimum part size for S3 multipart uploads""" -MIN_MIN_PART_SIZE = 5 * 1024 ** 2 +MIN_MIN_PART_SIZE = 5 * 1024**2 """The absolute minimum permitted by Amazon.""" -SCHEMES = ("s3", "s3n", 's3u', "s3a") +SCHEMES = ("s3", "s3n", "s3u", "s3a") DEFAULT_PORT = 443 -DEFAULT_HOST = 's3.amazonaws.com' +DEFAULT_HOST = "s3.amazonaws.com" DEFAULT_BUFFER_SIZE = 128 * 1024 URI_EXAMPLES = ( - 's3://my_bucket/my_key', - 's3://my_key:my_secret@my_bucket/my_key', - 's3://my_key:my_secret@my_server:my_port@my_bucket/my_key', + "s3://my_bucket/my_key", + "s3://my_key:my_secret@my_bucket/my_key", + "s3://my_key:my_secret@my_server:my_port@my_bucket/my_key", ) _UPLOAD_ATTEMPTS = 6 _SLEEP_SECONDS = 10 # Returned by AWS when we try to seek beyond EOF. -_OUT_OF_RANGE = 'InvalidRange' +_OUT_OF_RANGE = "InvalidRange" class _ClientWrapper: @@ -63,13 +63,14 @@ class _ClientWrapper: This wrapper behaves identically to the client otherwise. 
""" + def __init__(self, client, kwargs): self.client = client self.kwargs = kwargs def __getattr__(self, method_name): method = getattr(self.client, method_name) - kwargs = self.kwargs.get('S3.Client.%s' % method_name, {}) + kwargs = self.kwargs.get("S3.Client.%s" % method_name, {}) return functools.partial(method, **kwargs) @@ -105,19 +106,19 @@ def parse_uri(uri_as_string): # uri = split_uri.netloc + split_uri.path - if '@' in uri and ':' in uri.split('@')[0]: - auth, uri = uri.split('@', 1) - access_id, access_secret = auth.split(':') + if "@" in uri and ":" in uri.split("@")[0]: + auth, uri = uri.split("@", 1) + access_id, access_secret = auth.split(":") - head, key_id = uri.split('/', 1) - if '@' in head and ':' in head: + head, key_id = uri.split("/", 1) + if "@" in head and ":" in head: ordinary_calling_format = True - host_port, bucket_id = head.split('@') - host, port = host_port.split(':', 1) + host_port, bucket_id = head.split("@") + host, port = host_port.split(":", 1) port = int(port) - elif '@' in head: + elif "@" in head: ordinary_calling_format = True - host, bucket_id = head.split('@') + host, bucket_id = head.split("@") else: bucket_id = head @@ -150,41 +151,41 @@ def _consolidate_params(uri, transport_params): def inject(**kwargs): try: - client_kwargs = transport_params['client_kwargs'] + client_kwargs = transport_params["client_kwargs"] except KeyError: - client_kwargs = transport_params['client_kwargs'] = {} + client_kwargs = transport_params["client_kwargs"] = {} try: - init_kwargs = client_kwargs['S3.Client'] + init_kwargs = client_kwargs["S3.Client"] except KeyError: - init_kwargs = client_kwargs['S3.Client'] = {} + init_kwargs = client_kwargs["S3.Client"] = {} init_kwargs.update(**kwargs) - client = transport_params.get('client') - if client is not None and (uri['access_id'] or uri['access_secret']): + client = transport_params.get("client") + if client is not None and (uri["access_id"] or uri["access_secret"]): logger.warning( - 'ignoring credentials parsed from URL because they conflict with ' + "ignoring credentials parsed from URL because they conflict with " 'transport_params["client"]. Set transport_params["client"] to None ' - 'to suppress this warning.' + "to suppress this warning." ) uri.update(access_id=None, access_secret=None) - elif (uri['access_id'] and uri['access_secret']): + elif uri["access_id"] and uri["access_secret"]: inject( - aws_access_key_id=uri['access_id'], - aws_secret_access_key=uri['access_secret'], + aws_access_key_id=uri["access_id"], + aws_secret_access_key=uri["access_secret"], ) uri.update(access_id=None, access_secret=None) - if client is not None and uri['host'] != DEFAULT_HOST: + if client is not None and uri["host"] != DEFAULT_HOST: logger.warning( - 'ignoring endpoint_url parsed from URL because they conflict with ' + "ignoring endpoint_url parsed from URL because they conflict with " 'transport_params["client"]. Set transport_params["client"] to None ' - 'to suppress this warning.' + "to suppress this warning." 
) uri.update(host=None) - elif uri['host'] != DEFAULT_HOST: - inject(endpoint_url='https://%(host)s:%(port)d' % uri) + elif uri["host"] != DEFAULT_HOST: + inject(endpoint_url="https://%(host)s:%(port)d" % uri) uri.update(host=None) return uri, transport_params @@ -192,18 +193,18 @@ def inject(**kwargs): def open_uri(uri, mode, transport_params): deprecated = ( - 'multipart_upload_kwargs', - 'object_kwargs', - 'resource', - 'resource_kwargs', - 'session', - 'singlepart_upload_kwargs', + "multipart_upload_kwargs", + "object_kwargs", + "resource", + "resource_kwargs", + "session", + "singlepart_upload_kwargs", ) detected = [k for k in deprecated if k in transport_params] if detected: doc_url = ( - 'https://github.com/RaRe-Technologies/smart_open/blob/develop/' - 'MIGRATING_FROM_OLDER_VERSIONS.rst' + "https://github.com/RaRe-Technologies/smart_open/blob/develop/" + "MIGRATING_FROM_OLDER_VERSIONS.rst" ) # # We use warnings.warn /w UserWarning instead of logger.warn here because @@ -214,14 +215,14 @@ def open_uri(uri, mode, transport_params): # https://github.com/RaRe-Technologies/smart_open/issues/614 # message = ( - 'ignoring the following deprecated transport parameters: %r. ' - 'See <%s> for details' % (detected, doc_url) + "ignoring the following deprecated transport parameters: %r. " + "See <%s> for details" % (detected, doc_url) ) warnings.warn(message, UserWarning) parsed_uri = parse_uri(uri) parsed_uri, transport_params = _consolidate_params(parsed_uri, transport_params) kwargs = smart_open.utils.check_kwargs(open, transport_params) - return open(parsed_uri['bucket_id'], parsed_uri['key_id'], mode, **kwargs) + return open(parsed_uri["bucket_id"], parsed_uri["key_id"], mode, **kwargs) def open( @@ -280,9 +281,11 @@ def open( disk IO. If you pass in an open file, then you are responsible for cleaning it up after writing completes. 
""" - logger.debug('%r', locals()) + logger.debug("%r", locals()) if mode not in constants.BINARY_MODES: - raise NotImplementedError('bad mode: %r expected one of %r' % (mode, constants.BINARY_MODES)) + raise NotImplementedError( + "bad mode: %r expected one of %r" % (mode, constants.BINARY_MODES) + ) if (mode == constants.WRITE_BINARY) and (version_id is not None): raise ValueError("version_id must be None when writing") @@ -316,7 +319,7 @@ def open( writebuffer=writebuffer, ) else: - assert False, 'unexpected mode: %r' % mode + assert False, "unexpected mode: %r" % mode fileobj.name = key_id return fileobj @@ -325,14 +328,15 @@ def open( def _get(client, bucket, key, version, range_string): try: if version: - return client.get_object(Bucket=bucket, Key=key, VersionId=version, Range=range_string) + return client.get_object( + Bucket=bucket, Key=key, VersionId=version, Range=range_string + ) else: return client.get_object(Bucket=bucket, Key=key, Range=range_string) except botocore.client.ClientError as error: wrapped_error = IOError( - 'unable to access bucket: %r key: %r version: %r error: %s' % ( - bucket, key, version, error - ) + "unable to access bucket: %r key: %r version: %r error: %s" + % (bucket, key, version, error) ) wrapped_error.backend_error = error raise wrapped_error from error @@ -341,7 +345,7 @@ def _get(client, bucket, key, version, range_string): def _unwrap_ioerror(ioe): """Given an IOError from _get, return the 'Error' dictionary from boto.""" try: - return ioe.backend_error.response['Error'] + return ioe.backend_error.response["Error"] except (AttributeError, KeyError): return None @@ -378,7 +382,9 @@ def seek(self, offset, whence=constants.WHENCE_START): :rtype: int """ if whence not in constants.WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) + raise ValueError( + "invalid whence, expected one of %r" % constants.WHENCE_CHOICES + ) # # Close old body explicitly. @@ -445,9 +451,11 @@ def _open_body(self, start=None, stop=None): except IOError as ioe: # Handle requested content range exceeding content size. error_response = _unwrap_ioerror(ioe) - if error_response is None or error_response.get('Code') != _OUT_OF_RANGE: + if error_response is None or error_response.get("Code") != _OUT_OF_RANGE: raise - self._position = self._content_length = int(error_response['ActualObjectSize']) + self._position = self._content_length = int( + error_response["ActualObjectSize"] + ) self._body = io.BytesIO() else: # @@ -457,14 +465,16 @@ def _open_body(self, start=None, stop=None): # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#checking-retry-attempts-in-an-aws-service-response # logger.debug( - '%s: RetryAttempts: %d', + "%s: RetryAttempts: %d", self, - response['ResponseMetadata']['RetryAttempts'], + response["ResponseMetadata"]["RetryAttempts"], + ) + units, start, stop, length = smart_open.utils.parse_content_range( + response["ContentRange"] ) - units, start, stop, length = smart_open.utils.parse_content_range(response['ContentRange']) self._content_length = length self._position = start - self._body = response['Body'] + self._body = response["Body"] def read(self, size=-1): """Read from the continuous connection with the remote peer.""" @@ -472,7 +482,7 @@ def read(self, size=-1): # This is necessary for the very first read() after __init__(). 
self._open_body() if self._position >= self._content_length: - return b'' + return b"" # # Boto3 has built-in error handling and retry mechanisms: @@ -499,7 +509,7 @@ def read(self, size=-1): urllib3.exceptions.HTTPError, ) as err: logger.warning( - '%s: caught %r while reading %d bytes, sleeping %ds before retry', + "%s: caught %r while reading %d bytes, sleeping %ds before retry", self, err, size, @@ -511,10 +521,12 @@ def read(self, size=-1): self._position += len(binary) return binary - raise IOError('%s: failed to read %d bytes after %d attempts' % (self, size, attempt)) + raise IOError( + "%s: failed to read %d bytes after %d attempts" % (self, size, attempt) + ) def __str__(self): - return 'smart_open.s3._SeekableReader(%r, %r)' % (self._bucket, self._key) + return "smart_open.s3._SeekableReader(%r, %r)" % (self._bucket, self._key) def _initialize_boto3(rw, client, client_kwargs, bucket, key): @@ -524,8 +536,8 @@ def _initialize_boto3(rw, client, client_kwargs, bucket, key): client_kwargs = {} if client is None: - init_kwargs = client_kwargs.get('S3.Client', {}) - client = boto3.client('s3', **init_kwargs) + init_kwargs = client_kwargs.get("S3.Client", {}) + client = boto3.client("s3", **init_kwargs) assert client rw._client = _ClientWrapper(client, client_kwargs) @@ -564,6 +576,7 @@ def __init__( self._buffer = smart_open.bytebuffer.ByteBuffer(buffer_size) self._eof = False self._line_terminator = line_terminator + self._seek_initialized = False # # This member is part of the io.BufferedIOBase interface. @@ -588,7 +601,7 @@ def readable(self): def read(self, size=-1): """Read up to size bytes from the object and return them.""" if size == 0: - return b'' + return b"" elif size < 0: # call read() before setting _current_pos to make sure _content_length is set out = self._read_from_buffer() + self._raw_reader.read() @@ -620,13 +633,13 @@ def readinto(self, b): data = self.read(len(b)) if not data: return 0 - b[:len(data)] = data + b[: len(data)] = data return len(data) def readline(self, limit=-1): """Read up to and including the next newline. Returns the bytes read.""" if limit != -1: - raise NotImplementedError('limits other than -1 not implemented yet') + raise NotImplementedError("limits other than -1 not implemented yet") # # A single line may span multiple buffers. @@ -663,11 +676,15 @@ def seek(self, offset, whence=constants.WHENCE_START): whence = constants.WHENCE_START offset += self._current_pos - if not (whence == constants.WHENCE_START and offset == self._current_pos): + if not self._seek_initialized or not ( + whence == constants.WHENCE_START and offset == self._current_pos + ): self._current_pos = self._raw_reader.seek(offset, whence) self._buffer.empty() self._eof = self._current_pos == self._raw_reader._content_length + + self._seek_initialized = True return self._current_pos def tell(self): @@ -691,7 +708,7 @@ def to_boto3(self, resource): the same S3 object as this instance. Changes to the returned object will not affect the current instance. 
""" - assert resource, 'resource must be a boto3.resource instance' + assert resource, "resource must be a boto3.resource instance" obj = resource.Object(self._bucket, self._key) if self._version_id is not None: return obj.Version(self._version_id) @@ -713,7 +730,7 @@ def _fill_buffer(self, size=-1): while len(self._buffer) < size and not self._eof: bytes_read = self._buffer.fill(self._raw_reader) if bytes_read == 0: - logger.debug('%s: reached EOF while filling buffer', self) + logger.debug("%s: reached EOF while filling buffer", self) self._eof = True def __str__(self): @@ -751,8 +768,10 @@ def __init__( writebuffer=None, ): if min_part_size < MIN_MIN_PART_SIZE: - logger.warning("S3 requires minimum part size >= 5MB; \ -multipart upload may fail") + logger.warning( + "S3 requires minimum part size >= 5MB; \ +multipart upload may fail" + ) self._min_part_size = min_part_size _initialize_boto3(self, client, client_kwargs, bucket, key) @@ -763,12 +782,11 @@ def __init__( Bucket=bucket, Key=key, ) - self._upload_id = _retry_if_failed(partial)['UploadId'] + self._upload_id = _retry_if_failed(partial)["UploadId"] except botocore.client.ClientError as error: raise ValueError( - 'the bucket %r does not exist, or is forbidden for access (%r)' % ( - bucket, error - ) + "the bucket %r does not exist, or is forbidden for access (%r)" + % (bucket, error) ) from error if writebuffer is None: @@ -801,10 +819,10 @@ def close(self): Bucket=self._bucket, Key=self._key, UploadId=self._upload_id, - MultipartUpload={'Parts': self._parts}, + MultipartUpload={"Parts": self._parts}, ) _retry_if_failed(partial) - logger.debug('%s: completed multipart upload', self) + logger.debug("%s: completed multipart upload", self) elif self._upload_id: # # AWS complains with "The XML you provided was not well-formed or @@ -822,9 +840,9 @@ def close(self): self._client.put_object( Bucket=self._bucket, Key=self._key, - Body=b'', + Body=b"", ) - logger.debug('%s: wrote 0 bytes to imitate multipart upload', self) + logger.debug("%s: wrote 0 bytes to imitate multipart upload", self) self._upload_id = None @property @@ -891,7 +909,7 @@ def to_boto3(self, resource): the same S3 object as this instance. Changes to the returned object will not affect the current instance. 
""" - assert resource, 'resource must be a boto3.resource instance' + assert resource, "resource must be a boto3.resource instance" return resource.Object(self._bucket, self._key) # @@ -904,7 +922,7 @@ def _upload_next_part(self): self, part_num, self._buf.tell(), - self._total_bytes / 1024.0 ** 3, + self._total_bytes / 1024.0**3, ) self._buf.seek(0) @@ -925,7 +943,7 @@ def _upload_next_part(self): ) ) - self._parts.append({'ETag': upload['ETag'], 'PartNumber': part_num}) + self._parts.append({"ETag": upload["ETag"], "PartNumber": part_num}) logger.debug("%s: upload of part_num #%i finished", self, part_num) self._total_parts += 1 @@ -974,7 +992,9 @@ def __init__( try: self._client.head_bucket(Bucket=bucket) except botocore.client.ClientError as e: - raise ValueError('the bucket %r does not exist, or is forbidden for access' % bucket) from e + raise ValueError( + "the bucket %r does not exist, or is forbidden for access" % bucket + ) from e if writebuffer is None: self._buf = io.BytesIO() @@ -1008,7 +1028,9 @@ def close(self): ) except botocore.client.ClientError as e: raise ValueError( - 'the bucket %r does not exist, or is forbidden for access' % self._bucket) from e + "the bucket %r does not exist, or is forbidden for access" + % self._bucket + ) from e logger.debug("%s: direct upload finished", self) self._buf = None @@ -1050,7 +1072,8 @@ def write(self, b): interface implementation) into the buffer. Content of the buffer will be written to S3 on close as a single-part upload. - For more information about buffers, see https://docs.python.org/3/c-api/buffer.html""" + For more information about buffers, see https://docs.python.org/3/c-api/buffer.html + """ length = self._buf.write(b) self._total_bytes += length @@ -1073,32 +1096,36 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() def __str__(self): - return "smart_open.s3.SinglepartWriter(%r, %r)" % (self._object.bucket_name, self._object.key) + return "smart_open.s3.SinglepartWriter(%r, %r)" % ( + self._object.bucket_name, + self._object.key, + ) def __repr__(self): - return "smart_open.s3.SinglepartWriter(bucket=%r, key=%r)" % (self._bucket, self._key) + return "smart_open.s3.SinglepartWriter(bucket=%r, key=%r)" % ( + self._bucket, + self._key, + ) def _retry_if_failed( - partial, - attempts=_UPLOAD_ATTEMPTS, - sleep_seconds=_SLEEP_SECONDS, - exceptions=None): + partial, attempts=_UPLOAD_ATTEMPTS, sleep_seconds=_SLEEP_SECONDS, exceptions=None +): if exceptions is None: - exceptions = (botocore.exceptions.EndpointConnectionError, ) + exceptions = (botocore.exceptions.EndpointConnectionError,) for attempt in range(attempts): try: return partial() except exceptions: logger.critical( - 'Unable to connect to the endpoint. Check your network connection. ' - 'Sleeping and retrying %d more times ' - 'before giving up.' % (attempts - attempt - 1) + "Unable to connect to the endpoint. Check your network connection. " + "Sleeping and retrying %d more times " + "before giving up." % (attempts - attempt - 1) ) time.sleep(sleep_seconds) else: - logger.critical('Unable to connect to the endpoint. Giving up.') - raise IOError('Unable to connect to the endpoint after %d attempts' % attempts) + logger.critical("Unable to connect to the endpoint. 
Giving up.") + raise IOError("Unable to connect to the endpoint after %d attempts" % attempts) def _accept_all(key): @@ -1106,13 +1133,14 @@ def _accept_all(key): def iter_bucket( - bucket_name, - prefix='', - accept_key=None, - key_limit=None, - workers=16, - retries=3, - **session_kwargs): + bucket_name, + prefix="", + accept_key=None, + key_limit=None, + workers=16, + retries=3, + **session_kwargs +): """ Iterate and download all S3 objects under `s3://bucket_name/prefix`. @@ -1178,15 +1206,11 @@ def iter_bucket( total_size, key_no = 0, -1 key_iterator = _list_bucket( - bucket_name, - prefix=prefix, - accept_key=accept_key, - **session_kwargs) + bucket_name, prefix=prefix, accept_key=accept_key, **session_kwargs + ) download_key = functools.partial( - _download_key, - bucket_name=bucket_name, - retries=retries, - **session_kwargs) + _download_key, bucket_name=bucket_name, retries=retries, **session_kwargs + ) with smart_open.concurrency.create_pool(processes=workers) as pool: result_iterator = pool.imap_unordered(download_key, key_iterator) @@ -1197,7 +1221,10 @@ def iter_bucket( if key_no % 1000 == 0: logger.info( "yielding key #%i: %s, size %i (total %.1fMB)", - key_no, key, len(content), total_size / 1024.0 ** 2 + key_no, + key, + len(content), + total_size / 1024.0**2, ) yield key, content total_size += len(content) @@ -1210,7 +1237,10 @@ def iter_bucket( # after we listed the contents of the bucket, but before we # downloaded the object. # - if not ('Error' in err.response and err.response['Error'].get('Code') == '404'): + if not ( + "Error" in err.response + and err.response["Error"].get("Code") == "404" + ): raise err except StopIteration: break @@ -1218,13 +1248,9 @@ def iter_bucket( logger.info("processed %i keys, total size %i" % (key_no + 1, total_size)) -def _list_bucket( - bucket_name, - prefix='', - accept_key=lambda k: True, - **session_kwargs): +def _list_bucket(bucket_name, prefix="", accept_key=lambda k: True, **session_kwargs): session = boto3.session.Session(**session_kwargs) - client = session.client('s3') + client = session.client("s3") ctoken = None while True: @@ -1236,28 +1262,28 @@ def _list_bucket( kwargs = dict(Bucket=bucket_name, Prefix=prefix) response = client.list_objects_v2(**kwargs) try: - content = response['Contents'] + content = response["Contents"] except KeyError: pass else: for c in content: - key = c['Key'] + key = c["Key"] if accept_key(key): yield key - ctoken = response.get('NextContinuationToken', None) + ctoken = response.get("NextContinuationToken", None) if not ctoken: break def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs): if bucket_name is None: - raise ValueError('bucket_name may not be None') + raise ValueError("bucket_name may not be None") # # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/resources.html#multithreading-or-multiprocessing-with-resources # session = boto3.session.Session(**session_kwargs) - s3 = session.resource('s3') + s3 = session.resource("s3") bucket = s3.Bucket(bucket_name) # Sometimes, https://github.com/boto/boto/issues/2409 can happen diff --git a/smart_open/tests/test_s3_version.py b/smart_open/tests/test_s3_version.py index 94cd034d..b566a1fc 100644 --- a/smart_open/tests/test_s3_version.py +++ b/smart_open/tests/test_s3_version.py @@ -11,14 +11,14 @@ from smart_open import open -BUCKET_NAME = 'test-smartopen' -KEY_NAME = 'test-key' +BUCKET_NAME = "test-smartopen" +KEY_NAME = "test-key" logger = logging.getLogger(__name__) -_resource = functools.partial(boto3.resource, 
region_name='us-east-1') +_resource = functools.partial(boto3.resource, region_name="us-east-1") def get_versions(bucket, key): @@ -26,7 +26,7 @@ def get_versions(bucket, key): return [ v.id for v in sorted( - _resource('s3').Bucket(bucket).object_versions.filter(Prefix=key), + _resource("s3").Bucket(bucket).object_versions.filter(Prefix=key), key=lambda version: version.last_modified, ) ] @@ -39,78 +39,80 @@ def setUp(self): # Each run of this test reuses the BUCKET_NAME, but works with a # different key for isolation. # - resource = _resource('s3') + resource = _resource("s3") resource.create_bucket(Bucket=BUCKET_NAME).wait_until_exists() resource.BucketVersioning(BUCKET_NAME).enable() - self.key = 'test-write-key-{}'.format(uuid.uuid4().hex) + self.key = "test-write-key-{}".format(uuid.uuid4().hex) self.url = "s3://%s/%s" % (BUCKET_NAME, self.key) - self.test_ver1 = u"String version 1.0".encode('utf8') - self.test_ver2 = u"String version 2.0".encode('utf8') + self.test_ver1 = "String version 1.0".encode("utf8") + self.test_ver2 = "String version 2.0".encode("utf8") bucket = resource.Bucket(BUCKET_NAME) bucket.put_object(Key=self.key, Body=self.test_ver1) - logging.critical('versions after first write: %r', get_versions(BUCKET_NAME, self.key)) + logging.critical( + "versions after first write: %r", get_versions(BUCKET_NAME, self.key) + ) time.sleep(3) bucket.put_object(Key=self.key, Body=self.test_ver2) self.versions = get_versions(BUCKET_NAME, self.key) - logging.critical('versions after second write: %r', get_versions(BUCKET_NAME, self.key)) + logging.critical( + "versions after second write: %r", get_versions(BUCKET_NAME, self.key) + ) assert len(self.versions) == 2 def test_good_id(self): """Does passing the version_id parameter into the s3 submodule work correctly when reading?""" - params = {'version_id': self.versions[0]} - with open(self.url, mode='rb', transport_params=params) as fin: + params = {"version_id": self.versions[0]} + with open(self.url, mode="rb", transport_params=params) as fin: actual = fin.read() self.assertEqual(actual, self.test_ver1) def test_bad_id(self): """Does passing an invalid version_id exception into the s3 submodule get handled correctly?""" - params = {'version_id': 'bad-version-does-not-exist'} - with open(self.url, 'rb', transport_params=params) as fin: - with self.assertRaises(IOError): - fin.read() - + params = {"version_id": "bad-version-does-not-exist"} + with self.assertRaises(IOError): + open(self.url, "rb", transport_params=params) def test_bad_mode(self): """Do we correctly handle non-None version when writing?""" - params = {'version_id': self.versions[0]} + params = {"version_id": self.versions[0]} with self.assertRaises(ValueError): - open(self.url, 'wb', transport_params=params) + open(self.url, "wb", transport_params=params) def test_no_version(self): """Passing in no version at all gives the newest version of the file?""" - with open(self.url, 'rb') as fin: + with open(self.url, "rb") as fin: actual = fin.read() self.assertEqual(actual, self.test_ver2) def test_newest_version(self): """Passing in the newest version explicitly gives the most recent content?""" - params = {'version_id': self.versions[1]} - with open(self.url, mode='rb', transport_params=params) as fin: + params = {"version_id": self.versions[1]} + with open(self.url, mode="rb", transport_params=params) as fin: actual = fin.read() self.assertEqual(actual, self.test_ver2) def test_oldest_version(self): """Passing in the oldest version gives the oldest content?""" - params 
= {'version_id': self.versions[0]} - with open(self.url, mode='rb', transport_params=params) as fin: + params = {"version_id": self.versions[0]} + with open(self.url, mode="rb", transport_params=params) as fin: actual = fin.read() self.assertEqual(actual, self.test_ver1) def test_version_to_boto3(self): """Passing in the oldest version gives the oldest content?""" self.versions = get_versions(BUCKET_NAME, self.key) - params = {'version_id': self.versions[0]} - with open(self.url, mode='rb', transport_params=params) as fin: - returned_obj = fin.to_boto3(_resource('s3')) + params = {"version_id": self.versions[0]} + with open(self.url, mode="rb", transport_params=params) as fin: + returned_obj = fin.to_boto3(_resource("s3")) - boto3_body = boto3_body = returned_obj.get()['Body'].read() + boto3_body = boto3_body = returned_obj.get()["Body"].read() self.assertEqual(boto3_body, self.test_ver1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From 089340acb3501181d727db4db74cf9f4161853ff Mon Sep 17 00:00:00 2001 From: Christian Jensen Date: Wed, 18 Jan 2023 19:03:16 -0800 Subject: [PATCH 04/12] Add required import for example to work (#756) If a person were to simply copy this code block it would use the built in `open` and would not work. Adding in the correct import makes this block a bit easier for a simple copy paste. --- README.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/README.rst b/README.rst index 64435232..79b45844 100644 --- a/README.rst +++ b/README.rst @@ -151,6 +151,7 @@ For the sake of simplicity, the examples below assume you have all the dependenc .. code-block:: python >>> import os, boto3 + >>> from smart_open import open >>> >>> # stream content *into* S3 (write mode) using a custom session >>> session = boto3.Session( From 815202c64ac51968a7632c9c69a21df9fcd9c2da Mon Sep 17 00:00:00 2001 From: tooptoop4 <33283496+tooptoop4@users.noreply.github.com> Date: Mon, 3 Jul 2023 07:24:27 +1000 Subject: [PATCH 05/12] run tests on py3.11 (#774) --- .github/workflows/python-package.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c2cb5b3d..767a691e 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -6,10 +6,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Setup up Python 3.10 + - name: Setup up Python 3.11 uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.11" - name: Update pip run: python -m pip install -U pip @@ -26,15 +26,15 @@ jobs: strategy: matrix: include: - - {python: '3.7', os: ubuntu-20.04} - {python: '3.8', os: ubuntu-20.04} - {python: '3.9', os: ubuntu-20.04} - {python: '3.10', os: ubuntu-20.04} + - {python: '3.11', os: ubuntu-20.04} - - {python: '3.7', os: windows-2019} - {python: '3.8', os: windows-2019} - {python: '3.9', os: windows-2019} - {python: '3.10', os: windows-2019} + - {python: '3.11', os: windows-2019} steps: - uses: actions/checkout@v2 @@ -63,10 +63,10 @@ jobs: strategy: matrix: include: - - {python: '3.7', os: ubuntu-20.04} - {python: '3.8', os: ubuntu-20.04} - {python: '3.9', os: ubuntu-20.04} - {python: '3.10', os: ubuntu-20.04} + - {python: '3.11', os: ubuntu-20.04} # # Some of the doctests don't pass on Windows because of Windows-specific @@ -105,10 +105,10 @@ jobs: strategy: matrix: include: - - {python: '3.7', os: ubuntu-20.04, moto_server: true} - {python: '3.8', os: ubuntu-20.04} - {python: '3.9', os: ubuntu-20.04} - 
{python: '3.10', os: ubuntu-20.04} + - {python: '3.11', os: ubuntu-20.04} # Not sure why we exclude these, perhaps for historical reasons? # @@ -159,10 +159,10 @@ jobs: strategy: matrix: include: - - {python: '3.7', os: ubuntu-20.04} - {python: '3.8', os: ubuntu-20.04} - {python: '3.9', os: ubuntu-20.04} - {python: '3.10', os: ubuntu-20.04} + - {python: '3.11', os: ubuntu-20.04} # - {python: '3.7', os: windows-2019} # - {python: '3.8', os: windows-2019} From 543ff2e3a32e96d12681535a45fc9afdf474819f Mon Sep 17 00:00:00 2001 From: Brian Beck Date: Fri, 1 Sep 2023 20:38:08 -0600 Subject: [PATCH 06/12] add type command to ftp (#781) --- smart_open/ftp.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/smart_open/ftp.py b/smart_open/ftp.py index 3dbe26f1..7d4a5ad5 100644 --- a/smart_open/ftp.py +++ b/smart_open/ftp.py @@ -14,6 +14,7 @@ import smart_open.utils from ftplib import FTP, FTP_TLS, error_reply import types + logger = logging.getLogger(__name__) SCHEMES = ("ftp", "ftps") @@ -55,8 +56,13 @@ def open_uri(uri, mode, transport_params): uri_path = parsed_uri.pop("uri_path") scheme = parsed_uri.pop("scheme") secure_conn = True if scheme == "ftps" else False - return open(uri_path, mode, secure_connection=secure_conn, - transport_params=transport_params, **parsed_uri) + return open( + uri_path, + mode, + secure_connection=secure_conn, + transport_params=transport_params, + **parsed_uri, + ) def convert_transport_params_to_args(transport_params): @@ -90,7 +96,9 @@ def _connect(hostname, username, port, password, secure_connection, transport_pa try: ftp.login(username, password) except error_reply as e: - logger.error("Unable to login to FTP server: try checking the username and password!") + logger.error( + "Unable to login to FTP server: try checking the username and password!" 
+ ) raise e if secure_connection: ftp.prot_p() @@ -99,7 +107,7 @@ def _connect(hostname, username, port, password, secure_connection, transport_pa def open( path, - mode="r", + mode="rb", host=None, user=None, password=None, @@ -146,6 +154,7 @@ def open( except KeyError: raise ValueError(f"unsupported mode: {mode!r}") ftp_mode, file_obj_mode = mode_to_ftp_cmds[mode] + conn.voidcmd("TYPE I") socket = conn.transfercmd(f"{ftp_mode} {path}") fobj = socket.makefile(file_obj_mode) @@ -153,6 +162,7 @@ def full_close(self): self.orig_close() self.socket.close() self.conn.close() + fobj.orig_close = fobj.close fobj.socket = socket fobj.conn = conn From 99ee60377545df54b214d6b6daac2b6636746d41 Mon Sep 17 00:00:00 2001 From: tooptoop4 <33283496+tooptoop4@users.noreply.github.com> Date: Wed, 6 Sep 2023 13:35:54 +1000 Subject: [PATCH 07/12] Add python 3.11 to setup.py (#775) --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 57ad9e66..8c08b8e9 100644 --- a/setup.py +++ b/setup.py @@ -95,6 +95,7 @@ def read(fname): 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', 'Topic :: System :: Distributed Computing', 'Topic :: Database :: Front-Ends', ], From b528239425be0d539c1c614c9991242c46bc9259 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Cohen?= Date: Wed, 6 Sep 2023 05:38:54 +0200 Subject: [PATCH 08/12] Fixes KeyError when retrieving empty but existing object from S3 (#771) * fix: Fixes KeyError when retrieving empty file from S3 * Add test --- smart_open/s3.py | 33 ++++++++++++++++++++++++--------- smart_open/tests/test_s3.py | 11 +++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index 0a5846c7..e4179968 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -327,12 +327,13 @@ def open( def _get(client, bucket, key, version, range_string): try: + params = dict(Bucket=bucket, Key=key) if version: - return client.get_object( - Bucket=bucket, Key=key, VersionId=version, Range=range_string - ) - else: - return client.get_object(Bucket=bucket, Key=key, Range=range_string) + params["VersionId"] = version + if range_string: + params["Range"] = range_string + + return client.get_object(**params) except botocore.client.ClientError as error: wrapped_error = IOError( "unable to access bucket: %r key: %r version: %r error: %s" @@ -453,10 +454,21 @@ def _open_body(self, start=None, stop=None): error_response = _unwrap_ioerror(ioe) if error_response is None or error_response.get("Code") != _OUT_OF_RANGE: raise - self._position = self._content_length = int( - error_response["ActualObjectSize"] - ) - self._body = io.BytesIO() + try: + self._position = self._content_length = int( + error_response["ActualObjectSize"] + ) + self._body = io.BytesIO() + except KeyError: + response = _get( + self._client, + self._bucket, + self._key, + self._version_id, + None, + ) + self._position = self._content_length = response["ContentLength"] + self._body = response["Body"] else: # # Keep track of how many times boto3's built-in retry mechanism @@ -472,6 +484,9 @@ def _open_body(self, start=None, stop=None): units, start, stop, length = smart_open.utils.parse_content_range( response["ContentRange"] ) + _, start, stop, length = smart_open.utils.parse_content_range( + response["ContentRange"] + ) self._content_length = length self._position = start self._body = response["Body"] diff --git 
a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index a91a731e..fa907101 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -73,6 +73,8 @@ def mock_get(*args, **kwargs): error_response['ActualObjectSize'] = actual_size error_response['Code'] = 'InvalidRange' error_response['Message'] = 'The requested range is not satisfiable' + if actual_size is None: + error_response.pop('ActualObjectSize', None) raise with mock.patch('smart_open.s3._get', new=mock_get): @@ -399,6 +401,15 @@ def test_read_empty_file(self): self.assertEqual(data, b'') + def test_read_empty_file_no_actual_size(self): + _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=b'') + + with self.assertApiCalls(GetObject=2), patch_invalid_range_response(None): + with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) as fin: + data = fin.read() + + self.assertEqual(data, b'') + @moto.mock_s3 class MultipartWriterTest(unittest.TestCase): From 78d109d87cd1f78a6e778340e31d106c4e486edb Mon Sep 17 00:00:00 2001 From: Ron Reiter Date: Wed, 6 Sep 2023 06:43:06 +0300 Subject: [PATCH 09/12] bugfix: when read size > chunk size, return read size and not chunk size (#767) --- smart_open/azure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smart_open/azure.py b/smart_open/azure.py index 96f944a1..ccc19059 100644 --- a/smart_open/azure.py +++ b/smart_open/azure.py @@ -306,7 +306,7 @@ def read(self, size=-1): if self._position == self._size: return self._read_from_buffer() - self._fill_buffer() + self._fill_buffer(size) return self._read_from_buffer(size) def read1(self, size=-1): From 347f50c462c12b4e822757f01724a6a4bed64e11 Mon Sep 17 00:00:00 2001 From: Brian Beck Date: Wed, 6 Sep 2023 18:03:25 -0600 Subject: [PATCH 10/12] undo formatting --- smart_open/s3.py | 265 ++++++++++++++++++++++------------------------- 1 file changed, 124 insertions(+), 141 deletions(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index a0ddd0db..135dde15 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -31,26 +31,26 @@ DEFAULT_MIN_PART_SIZE = 50 * 1024**2 """Default minimum part size for S3 multipart uploads""" -MIN_MIN_PART_SIZE = 5 * 1024**2 +MIN_MIN_PART_SIZE = 5 * 1024 ** 2 """The absolute minimum permitted by Amazon.""" -SCHEMES = ("s3", "s3n", "s3u", "s3a") +SCHEMES = ("s3", "s3n", 's3u', "s3a") DEFAULT_PORT = 443 -DEFAULT_HOST = "s3.amazonaws.com" +DEFAULT_HOST = 's3.amazonaws.com' DEFAULT_BUFFER_SIZE = 128 * 1024 URI_EXAMPLES = ( - "s3://my_bucket/my_key", - "s3://my_key:my_secret@my_bucket/my_key", - "s3://my_key:my_secret@my_server:my_port@my_bucket/my_key", + 's3://my_bucket/my_key', + 's3://my_key:my_secret@my_bucket/my_key', + 's3://my_key:my_secret@my_server:my_port@my_bucket/my_key', ) _UPLOAD_ATTEMPTS = 6 _SLEEP_SECONDS = 10 # Returned by AWS when we try to seek beyond EOF. -_OUT_OF_RANGE = "InvalidRange" +_OUT_OF_RANGE = 'InvalidRange' class _ClientWrapper: @@ -63,14 +63,13 @@ class _ClientWrapper: This wrapper behaves identically to the client otherwise. 
""" - def __init__(self, client, kwargs): self.client = client self.kwargs = kwargs def __getattr__(self, method_name): method = getattr(self.client, method_name) - kwargs = self.kwargs.get("S3.Client.%s" % method_name, {}) + kwargs = self.kwargs.get('S3.Client.%s' % method_name, {}) return functools.partial(method, **kwargs) @@ -106,19 +105,19 @@ def parse_uri(uri_as_string): # uri = split_uri.netloc + split_uri.path - if "@" in uri and ":" in uri.split("@")[0]: - auth, uri = uri.split("@", 1) - access_id, access_secret = auth.split(":") + if '@' in uri and ':' in uri.split('@')[0]: + auth, uri = uri.split('@', 1) + access_id, access_secret = auth.split(':') - head, key_id = uri.split("/", 1) - if "@" in head and ":" in head: + head, key_id = uri.split('/', 1) + if '@' in head and ':' in head: ordinary_calling_format = True - host_port, bucket_id = head.split("@") - host, port = host_port.split(":", 1) + host_port, bucket_id = head.split('@') + host, port = host_port.split(':', 1) port = int(port) - elif "@" in head: + elif '@' in head: ordinary_calling_format = True - host, bucket_id = head.split("@") + host, bucket_id = head.split('@') else: bucket_id = head @@ -151,41 +150,41 @@ def _consolidate_params(uri, transport_params): def inject(**kwargs): try: - client_kwargs = transport_params["client_kwargs"] + client_kwargs = transport_params['client_kwargs'] except KeyError: - client_kwargs = transport_params["client_kwargs"] = {} + client_kwargs = transport_params['client_kwargs'] = {} try: - init_kwargs = client_kwargs["S3.Client"] + init_kwargs = client_kwargs['S3.Client'] except KeyError: - init_kwargs = client_kwargs["S3.Client"] = {} + init_kwargs = client_kwargs['S3.Client'] = {} init_kwargs.update(**kwargs) - client = transport_params.get("client") - if client is not None and (uri["access_id"] or uri["access_secret"]): + client = transport_params.get('client') + if client is not None and (uri['access_id'] or uri['access_secret']): logger.warning( - "ignoring credentials parsed from URL because they conflict with " + 'ignoring credentials parsed from URL because they conflict with ' 'transport_params["client"]. Set transport_params["client"] to None ' - "to suppress this warning." + 'to suppress this warning.' ) uri.update(access_id=None, access_secret=None) - elif uri["access_id"] and uri["access_secret"]: + elif (uri['access_id'] and uri['access_secret']): inject( - aws_access_key_id=uri["access_id"], - aws_secret_access_key=uri["access_secret"], + aws_access_key_id=uri['access_id'], + aws_secret_access_key=uri['access_secret'], ) uri.update(access_id=None, access_secret=None) - if client is not None and uri["host"] != DEFAULT_HOST: + if client is not None and uri['host'] != DEFAULT_HOST: logger.warning( - "ignoring endpoint_url parsed from URL because they conflict with " + 'ignoring endpoint_url parsed from URL because they conflict with ' 'transport_params["client"]. Set transport_params["client"] to None ' - "to suppress this warning." + 'to suppress this warning.' 
) uri.update(host=None) - elif uri["host"] != DEFAULT_HOST: - inject(endpoint_url="https://%(host)s:%(port)d" % uri) + elif uri['host'] != DEFAULT_HOST: + inject(endpoint_url='https://%(host)s:%(port)d' % uri) uri.update(host=None) return uri, transport_params @@ -193,18 +192,18 @@ def inject(**kwargs): def open_uri(uri, mode, transport_params): deprecated = ( - "multipart_upload_kwargs", - "object_kwargs", - "resource", - "resource_kwargs", - "session", - "singlepart_upload_kwargs", + 'multipart_upload_kwargs', + 'object_kwargs', + 'resource', + 'resource_kwargs', + 'session', + 'singlepart_upload_kwargs', ) detected = [k for k in deprecated if k in transport_params] if detected: doc_url = ( - "https://github.com/RaRe-Technologies/smart_open/blob/develop/" - "MIGRATING_FROM_OLDER_VERSIONS.rst" + 'https://github.com/RaRe-Technologies/smart_open/blob/develop/' + 'MIGRATING_FROM_OLDER_VERSIONS.rst' ) # # We use warnings.warn /w UserWarning instead of logger.warn here because @@ -215,14 +214,14 @@ def open_uri(uri, mode, transport_params): # https://github.com/RaRe-Technologies/smart_open/issues/614 # message = ( - "ignoring the following deprecated transport parameters: %r. " - "See <%s> for details" % (detected, doc_url) + 'ignoring the following deprecated transport parameters: %r. ' + 'See <%s> for details' % (detected, doc_url) ) warnings.warn(message, UserWarning) parsed_uri = parse_uri(uri) parsed_uri, transport_params = _consolidate_params(parsed_uri, transport_params) kwargs = smart_open.utils.check_kwargs(open, transport_params) - return open(parsed_uri["bucket_id"], parsed_uri["key_id"], mode, **kwargs) + return open(parsed_uri['bucket_id'], parsed_uri['key_id'], mode, **kwargs) def open( @@ -281,11 +280,9 @@ def open( disk IO. If you pass in an open file, then you are responsible for cleaning it up after writing completes. 
""" - logger.debug("%r", locals()) + logger.debug('%r', locals()) if mode not in constants.BINARY_MODES: - raise NotImplementedError( - "bad mode: %r expected one of %r" % (mode, constants.BINARY_MODES) - ) + raise NotImplementedError('bad mode: %r expected one of %r' % (mode, constants.BINARY_MODES)) if (mode == constants.WRITE_BINARY) and (version_id is not None): raise ValueError("version_id must be None when writing") @@ -319,7 +316,7 @@ def open( writebuffer=writebuffer, ) else: - assert False, "unexpected mode: %r" % mode + assert False, 'unexpected mode: %r' % mode fileobj.name = key_id return fileobj @@ -336,8 +333,9 @@ def _get(client, bucket, key, version, range_string): return client.get_object(**params) except botocore.client.ClientError as error: wrapped_error = IOError( - "unable to access bucket: %r key: %r version: %r error: %s" - % (bucket, key, version, error) + 'unable to access bucket: %r key: %r version: %r error: %s' % ( + bucket, key, version, error + ) ) wrapped_error.backend_error = error raise wrapped_error from error @@ -346,7 +344,7 @@ def _get(client, bucket, key, version, range_string): def _unwrap_ioerror(ioe): """Given an IOError from _get, return the 'Error' dictionary from boto.""" try: - return ioe.backend_error.response["Error"] + return ioe.backend_error.response['Error'] except (AttributeError, KeyError): return None @@ -383,9 +381,7 @@ def seek(self, offset, whence=constants.WHENCE_START): :rtype: int """ if whence not in constants.WHENCE_CHOICES: - raise ValueError( - "invalid whence, expected one of %r" % constants.WHENCE_CHOICES - ) + raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) # # Close old body explicitly. @@ -452,7 +448,7 @@ def _open_body(self, start=None, stop=None): except IOError as ioe: # Handle requested content range exceeding content size. error_response = _unwrap_ioerror(ioe) - if error_response is None or error_response.get("Code") != _OUT_OF_RANGE: + if error_response is None or error_response.get('Code') != _OUT_OF_RANGE: raise try: self._position = self._content_length = int(error_response['ActualObjectSize']) @@ -475,17 +471,14 @@ def _open_body(self, start=None, stop=None): # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#checking-retry-attempts-in-an-aws-service-response # logger.debug( - "%s: RetryAttempts: %d", + '%s: RetryAttempts: %d', self, - response["ResponseMetadata"]["RetryAttempts"], - ) - units, start, stop, length = smart_open.utils.parse_content_range( - response["ContentRange"] + response['ResponseMetadata']['RetryAttempts'], ) _, start, stop, length = smart_open.utils.parse_content_range(response['ContentRange']) self._content_length = length self._position = start - self._body = response["Body"] + self._body = response['Body'] def read(self, size=-1): """Read from the continuous connection with the remote peer.""" @@ -493,7 +486,7 @@ def read(self, size=-1): # This is necessary for the very first read() after __init__(). 
self._open_body() if self._position >= self._content_length: - return b"" + return b'' # # Boto3 has built-in error handling and retry mechanisms: @@ -520,7 +513,7 @@ def read(self, size=-1): urllib3.exceptions.HTTPError, ) as err: logger.warning( - "%s: caught %r while reading %d bytes, sleeping %ds before retry", + '%s: caught %r while reading %d bytes, sleeping %ds before retry', self, err, size, @@ -532,12 +525,10 @@ def read(self, size=-1): self._position += len(binary) return binary - raise IOError( - "%s: failed to read %d bytes after %d attempts" % (self, size, attempt) - ) + raise IOError('%s: failed to read %d bytes after %d attempts' % (self, size, attempt)) def __str__(self): - return "smart_open.s3._SeekableReader(%r, %r)" % (self._bucket, self._key) + return 'smart_open.s3._SeekableReader(%r, %r)' % (self._bucket, self._key) def _initialize_boto3(rw, client, client_kwargs, bucket, key): @@ -547,8 +538,8 @@ def _initialize_boto3(rw, client, client_kwargs, bucket, key): client_kwargs = {} if client is None: - init_kwargs = client_kwargs.get("S3.Client", {}) - client = boto3.client("s3", **init_kwargs) + init_kwargs = client_kwargs.get('S3.Client', {}) + client = boto3.client('s3', **init_kwargs) assert client rw._client = _ClientWrapper(client, client_kwargs) @@ -612,7 +603,7 @@ def readable(self): def read(self, size=-1): """Read up to size bytes from the object and return them.""" if size == 0: - return b"" + return b'' elif size < 0: # call read() before setting _current_pos to make sure _content_length is set out = self._read_from_buffer() + self._raw_reader.read() @@ -644,13 +635,13 @@ def readinto(self, b): data = self.read(len(b)) if not data: return 0 - b[: len(data)] = data + b[:len(data)] = data return len(data) def readline(self, limit=-1): """Read up to and including the next newline. Returns the bytes read.""" if limit != -1: - raise NotImplementedError("limits other than -1 not implemented yet") + raise NotImplementedError('limits other than -1 not implemented yet') # # A single line may span multiple buffers. @@ -691,6 +682,7 @@ def seek(self, offset, whence=constants.WHENCE_START): whence == constants.WHENCE_START and offset == self._current_pos ): self._current_pos = self._raw_reader.seek(offset, whence) + self._buffer.empty() self._eof = self._current_pos == self._raw_reader._content_length @@ -719,7 +711,7 @@ def to_boto3(self, resource): the same S3 object as this instance. Changes to the returned object will not affect the current instance. 
""" - assert resource, "resource must be a boto3.resource instance" + assert resource, 'resource must be a boto3.resource instance' obj = resource.Object(self._bucket, self._key) if self._version_id is not None: return obj.Version(self._version_id) @@ -741,7 +733,7 @@ def _fill_buffer(self, size=-1): while len(self._buffer) < size and not self._eof: bytes_read = self._buffer.fill(self._raw_reader) if bytes_read == 0: - logger.debug("%s: reached EOF while filling buffer", self) + logger.debug('%s: reached EOF while filling buffer', self) self._eof = True def __str__(self): @@ -779,10 +771,8 @@ def __init__( writebuffer=None, ): if min_part_size < MIN_MIN_PART_SIZE: - logger.warning( - "S3 requires minimum part size >= 5MB; \ -multipart upload may fail" - ) + logger.warning("S3 requires minimum part size >= 5MB; \ +multipart upload may fail") self._min_part_size = min_part_size _initialize_boto3(self, client, client_kwargs, bucket, key) @@ -793,11 +783,12 @@ def __init__( Bucket=bucket, Key=key, ) - self._upload_id = _retry_if_failed(partial)["UploadId"] + self._upload_id = _retry_if_failed(partial)['UploadId'] except botocore.client.ClientError as error: raise ValueError( - "the bucket %r does not exist, or is forbidden for access (%r)" - % (bucket, error) + 'the bucket %r does not exist, or is forbidden for access (%r)' % ( + bucket, error + ) ) from error if writebuffer is None: @@ -830,10 +821,10 @@ def close(self): Bucket=self._bucket, Key=self._key, UploadId=self._upload_id, - MultipartUpload={"Parts": self._parts}, + MultipartUpload={'Parts': self._parts}, ) _retry_if_failed(partial) - logger.debug("%s: completed multipart upload", self) + logger.debug('%s: completed multipart upload', self) elif self._upload_id: # # AWS complains with "The XML you provided was not well-formed or @@ -851,9 +842,9 @@ def close(self): self._client.put_object( Bucket=self._bucket, Key=self._key, - Body=b"", + Body=b'', ) - logger.debug("%s: wrote 0 bytes to imitate multipart upload", self) + logger.debug('%s: wrote 0 bytes to imitate multipart upload', self) self._upload_id = None @property @@ -920,7 +911,7 @@ def to_boto3(self, resource): the same S3 object as this instance. Changes to the returned object will not affect the current instance. 
""" - assert resource, "resource must be a boto3.resource instance" + assert resource, 'resource must be a boto3.resource instance' return resource.Object(self._bucket, self._key) # @@ -933,7 +924,7 @@ def _upload_next_part(self): self, part_num, self._buf.tell(), - self._total_bytes / 1024.0**3, + self._total_bytes / 1024.0 ** 3, ) self._buf.seek(0) @@ -954,7 +945,7 @@ def _upload_next_part(self): ) ) - self._parts.append({"ETag": upload["ETag"], "PartNumber": part_num}) + self._parts.append({'ETag': upload['ETag'], 'PartNumber': part_num}) logger.debug("%s: upload of part_num #%i finished", self, part_num) self._total_parts += 1 @@ -1003,9 +994,7 @@ def __init__( try: self._client.head_bucket(Bucket=bucket) except botocore.client.ClientError as e: - raise ValueError( - "the bucket %r does not exist, or is forbidden for access" % bucket - ) from e + raise ValueError('the bucket %r does not exist, or is forbidden for access' % bucket) from e if writebuffer is None: self._buf = io.BytesIO() @@ -1039,9 +1028,7 @@ def close(self): ) except botocore.client.ClientError as e: raise ValueError( - "the bucket %r does not exist, or is forbidden for access" - % self._bucket - ) from e + 'the bucket %r does not exist, or is forbidden for access' % self._bucket) from e logger.debug("%s: direct upload finished", self) self._buf = None @@ -1083,8 +1070,7 @@ def write(self, b): interface implementation) into the buffer. Content of the buffer will be written to S3 on close as a single-part upload. - For more information about buffers, see https://docs.python.org/3/c-api/buffer.html - """ + For more information about buffers, see https://docs.python.org/3/c-api/buffer.html""" length = self._buf.write(b) self._total_bytes += length @@ -1107,36 +1093,32 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() def __str__(self): - return "smart_open.s3.SinglepartWriter(%r, %r)" % ( - self._object.bucket_name, - self._object.key, - ) + return "smart_open.s3.SinglepartWriter(%r, %r)" % (self._object.bucket_name, self._object.key) def __repr__(self): - return "smart_open.s3.SinglepartWriter(bucket=%r, key=%r)" % ( - self._bucket, - self._key, - ) + return "smart_open.s3.SinglepartWriter(bucket=%r, key=%r)" % (self._bucket, self._key) def _retry_if_failed( - partial, attempts=_UPLOAD_ATTEMPTS, sleep_seconds=_SLEEP_SECONDS, exceptions=None -): + partial, + attempts=_UPLOAD_ATTEMPTS, + sleep_seconds=_SLEEP_SECONDS, + exceptions=None): if exceptions is None: - exceptions = (botocore.exceptions.EndpointConnectionError,) + exceptions = (botocore.exceptions.EndpointConnectionError, ) for attempt in range(attempts): try: return partial() except exceptions: logger.critical( - "Unable to connect to the endpoint. Check your network connection. " - "Sleeping and retrying %d more times " - "before giving up." % (attempts - attempt - 1) + 'Unable to connect to the endpoint. Check your network connection. ' + 'Sleeping and retrying %d more times ' + 'before giving up.' % (attempts - attempt - 1) ) time.sleep(sleep_seconds) else: - logger.critical("Unable to connect to the endpoint. Giving up.") - raise IOError("Unable to connect to the endpoint after %d attempts" % attempts) + logger.critical('Unable to connect to the endpoint. 
Giving up.') + raise IOError('Unable to connect to the endpoint after %d attempts' % attempts) def _accept_all(key): @@ -1144,14 +1126,13 @@ def _accept_all(key): def iter_bucket( - bucket_name, - prefix="", - accept_key=None, - key_limit=None, - workers=16, - retries=3, - **session_kwargs -): + bucket_name, + prefix='', + accept_key=None, + key_limit=None, + workers=16, + retries=3, + **session_kwargs): """ Iterate and download all S3 objects under `s3://bucket_name/prefix`. @@ -1217,11 +1198,15 @@ def iter_bucket( total_size, key_no = 0, -1 key_iterator = _list_bucket( - bucket_name, prefix=prefix, accept_key=accept_key, **session_kwargs - ) + bucket_name, + prefix=prefix, + accept_key=accept_key, + **session_kwargs) download_key = functools.partial( - _download_key, bucket_name=bucket_name, retries=retries, **session_kwargs - ) + _download_key, + bucket_name=bucket_name, + retries=retries, + **session_kwargs) with smart_open.concurrency.create_pool(processes=workers) as pool: result_iterator = pool.imap_unordered(download_key, key_iterator) @@ -1232,10 +1217,7 @@ def iter_bucket( if key_no % 1000 == 0: logger.info( "yielding key #%i: %s, size %i (total %.1fMB)", - key_no, - key, - len(content), - total_size / 1024.0**2, + key_no, key, len(content), total_size / 1024.0 ** 2 ) yield key, content total_size += len(content) @@ -1248,10 +1230,7 @@ def iter_bucket( # after we listed the contents of the bucket, but before we # downloaded the object. # - if not ( - "Error" in err.response - and err.response["Error"].get("Code") == "404" - ): + if not ('Error' in err.response and err.response['Error'].get('Code') == '404'): raise err except StopIteration: break @@ -1259,9 +1238,13 @@ def iter_bucket( logger.info("processed %i keys, total size %i" % (key_no + 1, total_size)) -def _list_bucket(bucket_name, prefix="", accept_key=lambda k: True, **session_kwargs): +def _list_bucket( + bucket_name, + prefix='', + accept_key=lambda k: True, + **session_kwargs): session = boto3.session.Session(**session_kwargs) - client = session.client("s3") + client = session.client('s3') ctoken = None while True: @@ -1273,28 +1256,28 @@ def _list_bucket(bucket_name, prefix="", accept_key=lambda k: True, **session_kw kwargs = dict(Bucket=bucket_name, Prefix=prefix) response = client.list_objects_v2(**kwargs) try: - content = response["Contents"] + content = response['Contents'] except KeyError: pass else: for c in content: - key = c["Key"] + key = c['Key'] if accept_key(key): yield key - ctoken = response.get("NextContinuationToken", None) + ctoken = response.get('NextContinuationToken', None) if not ctoken: break def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs): if bucket_name is None: - raise ValueError("bucket_name may not be None") + raise ValueError('bucket_name may not be None') # # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/resources.html#multithreading-or-multiprocessing-with-resources # session = boto3.session.Session(**session_kwargs) - s3 = session.resource("s3") + s3 = session.resource('s3') bucket = s3.Bucket(bucket_name) # Sometimes, https://github.com/boto/boto/issues/2409 can happen From a80e031f669f7bcd970209cafc95146d6e5467c0 Mon Sep 17 00:00:00 2001 From: Brian Beck Date: Wed, 6 Sep 2023 18:05:23 -0600 Subject: [PATCH 11/12] fix whitespace --- smart_open/s3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index 135dde15..18955fb3 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -682,7 +682,7 
@@ def seek(self, offset, whence=constants.WHENCE_START): whence == constants.WHENCE_START and offset == self._current_pos ): self._current_pos = self._raw_reader.seek(offset, whence) - + self._buffer.empty() self._eof = self._current_pos == self._raw_reader._content_length From 02609be793c3401d28aeb1ac327426efe487a75a Mon Sep 17 00:00:00 2001 From: Brian Beck Date: Wed, 6 Sep 2023 18:09:00 -0600 Subject: [PATCH 12/12] undo formatting --- smart_open/tests/test_s3_version.py | 56 ++++++++++++++--------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/smart_open/tests/test_s3_version.py b/smart_open/tests/test_s3_version.py index b566a1fc..6f9584d0 100644 --- a/smart_open/tests/test_s3_version.py +++ b/smart_open/tests/test_s3_version.py @@ -11,14 +11,14 @@ from smart_open import open -BUCKET_NAME = "test-smartopen" -KEY_NAME = "test-key" +BUCKET_NAME = 'test-smartopen' +KEY_NAME = 'test-key' logger = logging.getLogger(__name__) -_resource = functools.partial(boto3.resource, region_name="us-east-1") +_resource = functools.partial(boto3.resource, region_name='us-east-1') def get_versions(bucket, key): @@ -26,7 +26,7 @@ def get_versions(bucket, key): return [ v.id for v in sorted( - _resource("s3").Bucket(bucket).object_versions.filter(Prefix=key), + _resource('s3').Bucket(bucket).object_versions.filter(Prefix=key), key=lambda version: version.last_modified, ) ] @@ -39,80 +39,76 @@ def setUp(self): # Each run of this test reuses the BUCKET_NAME, but works with a # different key for isolation. # - resource = _resource("s3") + resource = _resource('s3') resource.create_bucket(Bucket=BUCKET_NAME).wait_until_exists() resource.BucketVersioning(BUCKET_NAME).enable() - self.key = "test-write-key-{}".format(uuid.uuid4().hex) + self.key = 'test-write-key-{}'.format(uuid.uuid4().hex) self.url = "s3://%s/%s" % (BUCKET_NAME, self.key) - self.test_ver1 = "String version 1.0".encode("utf8") - self.test_ver2 = "String version 2.0".encode("utf8") + self.test_ver1 = u"String version 1.0".encode('utf8') + self.test_ver2 = u"String version 2.0".encode('utf8') bucket = resource.Bucket(BUCKET_NAME) bucket.put_object(Key=self.key, Body=self.test_ver1) - logging.critical( - "versions after first write: %r", get_versions(BUCKET_NAME, self.key) - ) + logging.critical('versions after first write: %r', get_versions(BUCKET_NAME, self.key)) time.sleep(3) bucket.put_object(Key=self.key, Body=self.test_ver2) self.versions = get_versions(BUCKET_NAME, self.key) - logging.critical( - "versions after second write: %r", get_versions(BUCKET_NAME, self.key) - ) + logging.critical('versions after second write: %r', get_versions(BUCKET_NAME, self.key)) assert len(self.versions) == 2 def test_good_id(self): """Does passing the version_id parameter into the s3 submodule work correctly when reading?""" - params = {"version_id": self.versions[0]} - with open(self.url, mode="rb", transport_params=params) as fin: + params = {'version_id': self.versions[0]} + with open(self.url, mode='rb', transport_params=params) as fin: actual = fin.read() self.assertEqual(actual, self.test_ver1) def test_bad_id(self): """Does passing an invalid version_id exception into the s3 submodule get handled correctly?""" - params = {"version_id": "bad-version-does-not-exist"} + params = {'version_id': 'bad-version-does-not-exist'} with self.assertRaises(IOError): - open(self.url, "rb", transport_params=params) + open(self.url, 'rb', transport_params=params) def test_bad_mode(self): """Do we correctly handle non-None version when writing?""" 
- params = {"version_id": self.versions[0]} + params = {'version_id': self.versions[0]} with self.assertRaises(ValueError): - open(self.url, "wb", transport_params=params) + open(self.url, 'wb', transport_params=params) def test_no_version(self): """Passing in no version at all gives the newest version of the file?""" - with open(self.url, "rb") as fin: + with open(self.url, 'rb') as fin: actual = fin.read() self.assertEqual(actual, self.test_ver2) def test_newest_version(self): """Passing in the newest version explicitly gives the most recent content?""" - params = {"version_id": self.versions[1]} - with open(self.url, mode="rb", transport_params=params) as fin: + params = {'version_id': self.versions[1]} + with open(self.url, mode='rb', transport_params=params) as fin: actual = fin.read() self.assertEqual(actual, self.test_ver2) def test_oldest_version(self): """Passing in the oldest version gives the oldest content?""" - params = {"version_id": self.versions[0]} - with open(self.url, mode="rb", transport_params=params) as fin: + params = {'version_id': self.versions[0]} + with open(self.url, mode='rb', transport_params=params) as fin: actual = fin.read() self.assertEqual(actual, self.test_ver1) def test_version_to_boto3(self): """Passing in the oldest version gives the oldest content?""" self.versions = get_versions(BUCKET_NAME, self.key) - params = {"version_id": self.versions[0]} - with open(self.url, mode="rb", transport_params=params) as fin: - returned_obj = fin.to_boto3(_resource("s3")) + params = {'version_id': self.versions[0]} + with open(self.url, mode='rb', transport_params=params) as fin: + returned_obj = fin.to_boto3(_resource('s3')) - boto3_body = boto3_body = returned_obj.get()["Body"].read() + boto3_body = boto3_body = returned_obj.get()['Body'].read() self.assertEqual(boto3_body, self.test_ver1) -if __name__ == "__main__": +if __name__ == '__main__': unittest.main()