diff --git a/s3fs/core.py b/s3fs/core.py
index 068d2961..60662ba7 100644
--- a/s3fs/core.py
+++ b/s3fs/core.py
@@ -31,30 +31,6 @@
 
 _VALID_FILE_MODES = {'r', 'w', 'a', 'rb', 'wb', 'ab'}
 
-
-def split_path(path):
-    """
-    Normalise S3 path string into bucket and key.
-
-    Parameters
-    ----------
-    path : string
-        Input path, like `s3://mybucket/path/to/file`
-
-    Examples
-    --------
-    >>> split_path("s3://mybucket/path/to/file")
-    ['mybucket', 'path/to/file']
-    """
-    if path.startswith('s3://'):
-        path = path[5:]
-    path = path.rstrip('/').lstrip('/')
-    if '/' not in path:
-        return path, ""
-    else:
-        return path.split('/', 1)
-
-
 key_acls = {'private', 'public-read', 'public-read-write',
             'authenticated-read', 'aws-exec-read', 'bucket-owner-read',
             'bucket-owner-full-control'}
@@ -127,7 +103,7 @@ class S3FileSystem(AbstractFileSystem):
     connect_timeout = 5
     read_timeout = 15
     default_block_size = 5 * 2**20
-    protocol = 's3'
+    protocol = ['s3', 's3a']
     _extra_tokenize_attributes = ('default_block_size',)
 
     def __init__(self, anon=False, key=None, secret=None, token=None,
@@ -192,6 +168,27 @@ def _get_s3_method_kwargs(self, method, *akwarglist, **kwargs):
         # filter all kwargs
         return self._filter_kwargs(method, additional_kwargs)
 
+    def split_path(self, path):
+        """
+        Normalise S3 path string into bucket and key.
+
+        Parameters
+        ----------
+        path : string
+            Input path, like `s3://mybucket/path/to/file`
+
+        Examples
+        --------
+        >>> split_path("s3://mybucket/path/to/file")
+        ['mybucket', 'path/to/file']
+        """
+        path = self._strip_protocol(path)
+        path = path.lstrip('/')
+        if '/' not in path:
+            return path, ""
+        else:
+            return path.split('/', 1)
+
     def connect(self, refresh=True):
         """
         Establish S3 connection object.
@@ -315,10 +312,7 @@ def _open(self, path, mode='rb', block_size=None, acl='', version_id=None,
                       autocommit=autocommit, requester_pays=requester_pays)
 
     def _lsdir(self, path, refresh=False, max_items=None):
-        if path.startswith('s3://'):
-            path = path[len('s3://'):]
-        path = path.rstrip('/')
-        bucket, prefix = split_path(path)
+        bucket, prefix = self.split_path(path)
         prefix = prefix + '/' if prefix else ""
         if path not in self.dircache or refresh:
             try:
@@ -418,8 +412,7 @@ def _ls(self, path, refresh=False):
         refresh : bool (=False)
             if False, look in local cache for file details first
         """
-        if path.startswith('s3://'):
-            path = path[len('s3://'):]
+        path = self._strip_protocol(path)
         if path in ['', '/']:
             return self._lsbuckets(refresh)
         else:
@@ -429,7 +422,7 @@ def exists(self, path):
         if path in ['', '/']:
             # the root always exists, even if anon
             return True
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         if key:
             return super().exists(path)
         else:
@@ -441,7 +434,7 @@ def exists(self, path):
 
     def touch(self, path, truncate=True, data=None, **kwargs):
         """Create empty file or truncate"""
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         if not truncate and self.exists(path):
            raise ValueError("S3 does not support touching existent files")
         try:
@@ -461,7 +454,7 @@ def info(self, path, version_id=None):
             kwargs['VersionId'] = version_id
         if self.version_aware:
             try:
-                bucket, key = split_path(path)
+                bucket, key = self.split_path(path)
                 out = self._call_s3(self.s3.head_object, kwargs, Bucket=bucket,
                                     Key=key, **self.req_kw)
                 return {
@@ -540,7 +533,7 @@ def object_version_info(self, path, **kwargs):
         if not self.version_aware:
             raise ValueError("version specific functionality is disabled for "
                              "non-version aware filesystems")
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         kwargs = {}
         out = {'IsTruncated': True}
         versions = []
@@ -565,7 +558,7 @@ def metadata(self, path, refresh=False, **kwargs):
         refresh : bool (=False)
             if False, look in local cache for file metadata first
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
 
         if refresh or path not in self._metadata_cache:
             response = self._call_s3(self.s3.head_object,
@@ -584,7 +577,7 @@ def get_tags(self, path):
         -------
         {str: str}
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         response = self._call_s3(self.s3.get_object_tagging,
                                  Bucket=bucket, Key=key)
         return {v['Key']: v['Value'] for v in response['TagSet']}
@@ -610,7 +603,7 @@ def put_tags(self, path, tags, mode='o'):
             'm': Will merge in new tags with existing tags. Incurs two remote
                 calls.
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
 
         if mode == 'm':
             existing_tags = self.get_tags(path=path)
@@ -667,7 +660,7 @@ def setxattr(self, path, copy_kwargs=None, **kw_args):
 
         http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingMetadata.html#object-metadata
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         metadata = self.metadata(path)
         metadata.update(**kw_args)
         copy_kwargs = copy_kwargs or {}
@@ -702,7 +695,7 @@ def chmod(self, path, acl, **kwargs):
         acl : string
             the value of ACL to apply
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         if key:
             if acl not in key_acls:
                 raise ValueError('ACL not in %s', key_acls)
@@ -724,7 +717,7 @@ def url(self, path, expires=3600, **kwargs):
         expires : int
             the number of seconds this signature will be good for.
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         return self.s3.generate_presigned_url(
             ClientMethod='get_object',
             Params=dict(Bucket=bucket, Key=key, **kwargs),
@@ -743,7 +736,7 @@ def merge(self, path, filelist, **kwargs):
         filelist : list of str
             The paths, in order, to assemble into the final file.
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         mpu = self._call_s3(
             self.s3.create_multipart_upload,
             kwargs,
@@ -766,8 +759,8 @@ def merge(self, path, filelist, **kwargs):
 
     def copy_basic(self, path1, path2, **kwargs):
         """ Copy file between locations on S3 """
-        buc1, key1 = split_path(path1)
-        buc2, key2 = split_path(path2)
+        buc1, key1 = self.split_path(path1)
+        buc2, key2 = self.split_path(path2)
         try:
             self._call_s3(
                 self.s3.copy_object,
@@ -780,8 +773,8 @@ def copy_basic(self, path1, path2, **kwargs):
             raise ValueError('Copy failed (%r -> %r): %s' % (path1, path2, e))
 
     def copy_managed(self, path1, path2, **kwargs):
-        buc1, key1 = split_path(path1)
-        buc2, key2 = split_path(path2)
+        buc1, key1 = self.split_path(path1)
+        buc2, key2 = self.split_path(path2)
         copy_source = {
             'Bucket': buc1,
             'Key': key1
@@ -816,7 +809,7 @@ def bulk_delete(self, pathlist, **kwargs):
         """
         if not pathlist:
             return
-        buckets = {split_path(path)[0] for path in pathlist}
+        buckets = {self.split_path(path)[0] for path in pathlist}
         if len(buckets) > 1:
             raise ValueError("Bulk delete files should refer to only one "
                              "bucket")
@@ -825,7 +818,7 @@ def bulk_delete(self, pathlist, **kwargs):
             for i in range((len(pathlist) // 1000) + 1):
                 self.bulk_delete(pathlist[i * 1000:(i + 1) * 1000])
             return
-        delete_keys = {'Objects': [{'Key': split_path(path)[1]} for path
+        delete_keys = {'Objects': [{'Key': self.split_path(path)[1]} for path
                                    in pathlist]}
         for path in pathlist:
             self.invalidate_cache(self._parent(path))
@@ -849,7 +842,7 @@ def rm(self, path, recursive=False, **kwargs):
             Whether to remove also all entries below, i.e., which are returned
             by `walk()`.
         """
-        bucket, key = split_path(path)
+        bucket, key = self.split_path(path)
         if recursive:
             files = self.find(path, maxdepth=None)
             if key and not files:
@@ -887,7 +880,7 @@ def invalidate_cache(self, path=None):
             self.dircache.pop(self._parent(path), None)
 
     def walk(self, path, maxdepth=None, **kwargs):
-        if path in ['', '*', 's3://']:
+        if path in ['', '*'] + [f'{p}://' for p in self.protocol]:
             raise ValueError('Cannot crawl all of S3')
         return super().walk(path, maxdepth=maxdepth, **kwargs)
 
@@ -939,7 +932,7 @@ class S3File(AbstractBufferedFile):
     def __init__(self, s3, path, mode='rb', block_size=5 * 2 ** 20, acl="",
                  version_id=None, fill_cache=True, s3_additional_kwargs=None,
                  autocommit=True, cache_type='bytes', requester_pays=False):
-        bucket, key = split_path(path)
+        bucket, key = s3.split_path(path)
         if not key:
             raise ValueError('Attempt to open non key-like path: %s' % path)
         self.bucket = bucket
@@ -1060,7 +1053,7 @@ def _fetch_range(self, start, end):
         return _fetch_range(self.fs.s3, self.bucket, self.key, self.version_id, start, end, req_kw=self.req_kw)
 
     def _upload_chunk(self, final=False):
-        bucket, key = split_path(self.path)
+        bucket, key = self.s3.split_path(self.path)
         logger.debug("Upload for %s, final=%s, loc=%s, buffer loc=%s" % (
             self, final, self.loc, self.buffer.tell()
         ))