Merge pull request #269 from criteo-forks/add_more_fs_schemes
Add more s3 fs schemes (s3a://, s3n://)
martindurant authored Nov 28, 2019
2 parents 40a79bb + abcbcbb commit 87e5149
Showing 1 changed file with 45 additions and 52 deletions.
s3fs/core.py
@@ -31,30 +31,6 @@

_VALID_FILE_MODES = {'r', 'w', 'a', 'rb', 'wb', 'ab'}


def split_path(path):
    """
    Normalise S3 path string into bucket and key.

    Parameters
    ----------
    path : string
        Input path, like `s3://mybucket/path/to/file`

    Examples
    --------
    >>> split_path("s3://mybucket/path/to/file")
    ['mybucket', 'path/to/file']
    """
    if path.startswith('s3://'):
        path = path[5:]
    path = path.rstrip('/').lstrip('/')
    if '/' not in path:
        return path, ""
    else:
        return path.split('/', 1)


key_acls = {'private', 'public-read', 'public-read-write',
'authenticated-read', 'aws-exec-read', 'bucket-owner-read',
'bucket-owner-full-control'}
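
The module-level helper removed above hard-coded the `s3://` prefix, so any other scheme leaked into the bucket name. A standalone sketch of the failure (the function is copied here under a hypothetical name purely for illustration):

    # Copy of the removed helper, renamed for illustration.
    def old_split_path(path):
        if path.startswith('s3://'):
            path = path[5:]
        path = path.rstrip('/').lstrip('/')
        if '/' not in path:
            return path, ""
        else:
            return path.split('/', 1)

    print(old_split_path('s3://mybucket/path/to/file'))   # ('mybucket', 'path/to/file')
    print(old_split_path('s3a://mybucket/path/to/file'))  # ('s3a:', '/mybucket/path/to/file') -- scheme leaks into the bucket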
@@ -127,7 +103,7 @@ class S3FileSystem(AbstractFileSystem):
connect_timeout = 5
read_timeout = 15
default_block_size = 5 * 2**20
protocol = 's3'
protocol = ['s3', 's3a']
_extra_tokenize_attributes = ('default_block_size',)

def __init__(self, anon=False, key=None, secret=None, token=None,
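
Declaring `protocol = ['s3', 's3a']` tells fsspec's `AbstractFileSystem._strip_protocol` to accept either scheme for this class. A hedged sketch of the effect, assuming a hypothetical public bucket:

    import s3fs

    fs = s3fs.S3FileSystem(anon=True)
    # Both spellings now name the same object; with protocol = 's3' the
    # 's3a://' prefix was not stripped and lookups failed.
    with fs.open('s3a://mybucket/data.csv', 'rb') as f:
        head = f.read(100)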
@@ -192,6 +168,27 @@ def _get_s3_method_kwargs(self, method, *akwarglist, **kwargs):
# filter all kwargs
return self._filter_kwargs(method, additional_kwargs)

    def split_path(self, path):
        """
        Normalise S3 path string into bucket and key.

        Parameters
        ----------
        path : string
            Input path, like `s3://mybucket/path/to/file`

        Examples
        --------
        >>> split_path("s3://mybucket/path/to/file")
        ['mybucket', 'path/to/file']
        """
        path = self._strip_protocol(path)
        path = path.lstrip('/')
        if '/' not in path:
            return path, ""
        else:
            return path.split('/', 1)

def connect(self, refresh=True):
"""
Establish S3 connection object.
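
Moving `split_path` onto the instance lets it defer prefix stripping to `self._strip_protocol`, which knows every scheme in `protocol`. Expected behaviour, sketched below; note the method returns a tuple, although the docstring example inherited from the old helper shows a list:

    import s3fs

    fs = s3fs.S3FileSystem(anon=True)
    fs.split_path('s3://mybucket/path/to/file')   # ('mybucket', 'path/to/file')
    fs.split_path('s3a://mybucket/path/to/file')  # ('mybucket', 'path/to/file')
    fs.split_path('mybucket')                     # ('mybucket', '')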
@@ -315,10 +312,7 @@ def _open(self, path, mode='rb', block_size=None, acl='', version_id=None,
autocommit=autocommit, requester_pays=requester_pays)

def _lsdir(self, path, refresh=False, max_items=None):
if path.startswith('s3://'):
path = path[len('s3://'):]
path = path.rstrip('/')
bucket, prefix = split_path(path)
bucket, prefix = self.split_path(path)
prefix = prefix + '/' if prefix else ""
if path not in self.dircache or refresh:
try:
@@ -418,8 +412,7 @@ def _ls(self, path, refresh=False):
refresh : bool (=False)
if False, look in local cache for file details first
"""
if path.startswith('s3://'):
path = path[len('s3://'):]
path = self._strip_protocol(path)
if path in ['', '/']:
return self._lsbuckets(refresh)
else:
@@ -429,7 +422,7 @@ def exists(self, path):
if path in ['', '/']:
# the root always exists, even if anon
return True
bucket, key = split_path(path)
bucket, key = self.split_path(path)
if key:
return super().exists(path)
else:
@@ -441,7 +434,7 @@ def exists(self, path):

def touch(self, path, truncate=True, data=None, **kwargs):
"""Create empty file or truncate"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)
if not truncate and self.exists(path):
raise ValueError("S3 does not support touching existent files")
try:
Expand All @@ -461,7 +454,7 @@ def info(self, path, version_id=None):
kwargs['VersionId'] = version_id
if self.version_aware:
try:
bucket, key = split_path(path)
bucket, key = self.split_path(path)
out = self._call_s3(self.s3.head_object, kwargs, Bucket=bucket,
Key=key, **self.req_kw)
return {
@@ -540,7 +533,7 @@ def object_version_info(self, path, **kwargs):
if not self.version_aware:
raise ValueError("version specific functionality is disabled for "
"non-version aware filesystems")
bucket, key = split_path(path)
bucket, key = self.split_path(path)
kwargs = {}
out = {'IsTruncated': True}
versions = []
@@ -565,7 +558,7 @@ def metadata(self, path, refresh=False, **kwargs):
refresh : bool (=False)
if False, look in local cache for file metadata first
"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)

if refresh or path not in self._metadata_cache:
response = self._call_s3(self.s3.head_object,
@@ -584,7 +577,7 @@ def get_tags(self, path):
-------
{str: str}
"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)
response = self._call_s3(self.s3.get_object_tagging,
Bucket=bucket, Key=key)
return {v['Key']: v['Value'] for v in response['TagSet']}
@@ -610,7 +603,7 @@ def put_tags(self, path, tags, mode='o'):
'm': Will merge in new tags with existing tags. Incurs two remote
calls.
"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)

if mode == 'm':
existing_tags = self.get_tags(path=path)
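
With the default `mode='o'` the supplied dict replaces the object's tag set outright; `mode='m'` first fetches the existing tags and merges, costing one extra remote call. Sketch against a hypothetical object:

    import s3fs

    fs = s3fs.S3FileSystem()
    fs.put_tags('mybucket/report.csv', {'team': 'data'})            # tag set is exactly {'team': 'data'}
    fs.put_tags('mybucket/report.csv', {'stage': 'raw'}, mode='m')  # merged: {'team': 'data', 'stage': 'raw'}
    fs.get_tags('mybucket/report.csv')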
@@ -667,7 +660,7 @@ def setxattr(self, path, copy_kwargs=None, **kw_args):
http://docs.aws.amazon.com/AmazonS3/latest/dev/UsingMetadata.html#object-metadata
"""

bucket, key = split_path(path)
bucket, key = self.split_path(path)
metadata = self.metadata(path)
metadata.update(**kw_args)
copy_kwargs = copy_kwargs or {}
@@ -702,7 +695,7 @@ def chmod(self, path, acl, **kwargs):
acl : string
the value of ACL to apply
"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)
if key:
if acl not in key_acls:
raise ValueError('ACL not in %s', key_acls)
@@ -724,7 +717,7 @@ def url(self, path, expires=3600, **kwargs):
expires : int
the number of seconds this signature will be good for.
"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)
return self.s3.generate_presigned_url(
ClientMethod='get_object',
Params=dict(Bucket=bucket, Key=key, **kwargs),
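
`url` produces a presigned GET link via botocore, so the holder can fetch the object for `expires` seconds without AWS credentials of their own. Sketch (hypothetical object; any plain HTTP client could consume the link):

    import s3fs

    fs = s3fs.S3FileSystem()
    link = fs.url('mybucket/report.csv', expires=600)  # valid for ten minutes
    # e.g. https://mybucket.s3.amazonaws.com/report.csv?X-Amz-Signature=...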
@@ -743,7 +736,7 @@ def merge(self, path, filelist, **kwargs):
filelist : list of str
The paths, in order, to assemble into the final file.
"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)
mpu = self._call_s3(
self.s3.create_multipart_upload,
kwargs,
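
`merge` stitches existing objects together server-side through a multipart upload, so the data never passes through the client; note that S3 generally requires every part except the last to be at least 5 MB. Sketch with hypothetical paths:

    import s3fs

    fs = s3fs.S3FileSystem()
    fs.merge('mybucket/full.log',
             ['mybucket/part-000.log', 'mybucket/part-001.log'])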
@@ -766,8 +759,8 @@ def merge(self, path, filelist, **kwargs):

def copy_basic(self, path1, path2, **kwargs):
""" Copy file between locations on S3 """
buc1, key1 = split_path(path1)
buc2, key2 = split_path(path2)
buc1, key1 = self.split_path(path1)
buc2, key2 = self.split_path(path2)
try:
self._call_s3(
self.s3.copy_object,
@@ -780,8 +773,8 @@ def copy_basic(self, path1, path2, **kwargs):
raise ValueError('Copy failed (%r -> %r): %s' % (path1, path2, e))

def copy_managed(self, path1, path2, **kwargs):
buc1, key1 = split_path(path1)
buc2, key2 = split_path(path2)
buc1, key1 = self.split_path(path1)
buc2, key2 = self.split_path(path2)
copy_source = {
'Bucket': buc1,
'Key': key1
@@ -816,7 +809,7 @@ def bulk_delete(self, pathlist, **kwargs):
"""
if not pathlist:
return
buckets = {split_path(path)[0] for path in pathlist}
buckets = {self.split_path(path)[0] for path in pathlist}
if len(buckets) > 1:
raise ValueError("Bulk delete files should refer to only one "
"bucket")
@@ -825,7 +818,7 @@ def bulk_delete(self, pathlist, **kwargs):
for i in range((len(pathlist) // 1000) + 1):
self.bulk_delete(pathlist[i * 1000:(i + 1) * 1000])
return
delete_keys = {'Objects': [{'Key': split_path(path)[1]} for path
delete_keys = {'Objects': [{'Key': self.split_path(path)[1]} for path
in pathlist]}
for path in pathlist:
self.invalidate_cache(self._parent(path))
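
The DeleteObjects API accepts at most 1000 keys per request, hence the recursion into slices of 1000; `(len(pathlist) // 1000) + 1` slices cover any list length, and an exact multiple of 1000 merely yields a trailing empty batch that the early `return` guard absorbs. The arithmetic, sketched standalone:

    pathlist = ['mybucket/key-%05d' % i for i in range(2500)]
    batches = [pathlist[i * 1000:(i + 1) * 1000]
               for i in range((len(pathlist) // 1000) + 1)]
    print([len(b) for b in batches])  # [1000, 1000, 500]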
@@ -849,7 +842,7 @@ def rm(self, path, recursive=False, **kwargs):
Whether to remove also all entries below, i.e., which are returned
by `walk()`.
"""
bucket, key = split_path(path)
bucket, key = self.split_path(path)
if recursive:
files = self.find(path, maxdepth=None)
if key and not files:
@@ -887,7 +880,7 @@ def invalidate_cache(self, path=None):
self.dircache.pop(self._parent(path), None)

def walk(self, path, maxdepth=None, **kwargs):
if path in ['', '*', 's3://']:
if path in ['', '*'] + [f'{p}://' for p in self.protocol]:
raise ValueError('Cannot crawl all of S3')
return super().walk(path, maxdepth=maxdepth, **kwargs)
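
The guard previously caught only the literal `'s3://'`; building the list from `self.protocol` now rejects the bare root under any registered scheme, since crawling every bucket would be unbounded. Sketch:

    import s3fs

    fs = s3fs.S3FileSystem(anon=True)
    # fs.walk('s3a://')  # would raise ValueError('Cannot crawl all of S3')
    for dirpath, dirnames, filenames in fs.walk('mybucket'):  # hypothetical bucket
        print(dirpath, filenames)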

@@ -939,7 +932,7 @@ class S3File(AbstractBufferedFile):
def __init__(self, s3, path, mode='rb', block_size=5 * 2 ** 20, acl="",
version_id=None, fill_cache=True, s3_additional_kwargs=None,
autocommit=True, cache_type='bytes', requester_pays=False):
bucket, key = split_path(path)
bucket, key = s3.split_path(path)
if not key:
raise ValueError('Attempt to open non key-like path: %s' % path)
self.bucket = bucket
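
`S3File` now asks the owning filesystem to split the path, so a file opened under any scheme resolves to the same bucket and key; opening a bare bucket still fails because the key part is empty. Sketch:

    import s3fs

    fs = s3fs.S3FileSystem(anon=True)
    f = fs.open('s3a://mybucket/path/to/file', 'rb')  # S3File with bucket='mybucket'
    # fs.open('s3a://mybucket', 'rb')  # would raise: 'Attempt to open non key-like path'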
@@ -1060,7 +1053,7 @@ def _fetch_range(self, start, end):
return _fetch_range(self.fs.s3, self.bucket, self.key, self.version_id, start, end, req_kw=self.req_kw)

def _upload_chunk(self, final=False):
bucket, key = split_path(self.path)
bucket, key = self.s3.split_path(self.path)
logger.debug("Upload for %s, final=%s, loc=%s, buffer loc=%s" % (
self, final, self.loc, self.buffer.tell()
))
