Add type hints
laughingman7743 committed May 2, 2024
1 parent f9254c9 commit db396a2
Showing 2 changed files with 67 additions and 53 deletions.
85 changes: 49 additions & 36 deletions pyathena/filesystem/s3.py
@@ -7,6 +7,7 @@
 from concurrent.futures import Future, as_completed
 from concurrent.futures.thread import ThreadPoolExecutor
 from copy import deepcopy
+from datetime import datetime
 from multiprocessing import cpu_count
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Pattern, Tuple, Union, cast

@@ -240,7 +241,7 @@ def _ls_dirs(
         path: str,
         prefix: str = "",
         delimiter: str = "/",
-        next_token=None,
+        next_token: Optional[str] = None,
         max_keys: Optional[int] = None,
         refresh: bool = False,
     ) -> List[S3Object]:
@@ -297,7 +298,9 @@ def _ls_dirs(
             files = self.dircache[path]
         return files
 
-    def ls(self, path, detail=False, refresh=False, **kwargs):
+    def ls(
+        self, path: str, detail: bool = False, refresh: bool = False, **kwargs
+    ) -> Union[List[S3Object], List[str]]:
         path = self._strip_protocol(path).rstrip("/")
         if path in ["", "/"]:
             files = self._ls_buckets(refresh)
@@ -309,7 +312,7 @@ def ls(self, path, detail=False, refresh=False, **kwargs):
                 files = [file]
         return [f for f in files] if detail else [f.name for f in files]
 
-    def info(self, path, **kwargs) -> S3Object:
+    def info(self, path: str, **kwargs) -> S3Object:
         refresh = kwargs.pop("refresh", False)
         path = self._strip_protocol(path)
         bucket, key, path_version_id = self.parse_path(path)
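
A minimal usage sketch (not part of this commit) of why ls is now annotated Union[List[S3Object], List[str]]: the detail flag switches between rich objects and plain names. The bucket and key below are placeholders, and constructing S3FileSystem with no arguments assumes default credential resolution.

    from pyathena.filesystem.s3 import S3FileSystem

    fs = S3FileSystem()  # assumption: defaults are enough to build a client

    names = fs.ls("s3://example-bucket/prefix")                 # List[str]
    objects = fs.ls("s3://example-bucket/prefix", detail=True)  # List[S3Object]
    info = fs.info("s3://example-bucket/prefix/data.csv")       # single S3Object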
@@ -387,7 +390,15 @@ def info(self, path, **kwargs) -> S3Object:
         else:
             raise FileNotFoundError(path)
 
-    def find(self, path, maxdepth=None, withdirs=None, detail=False, **kwargs):
+    def find(
+        self,
+        path: str,
+        maxdepth: Optional[int] = None,
+        withdirs: Optional[bool] = None,
+        detail: bool = False,
+        **kwargs,
+    ) -> Union[Dict[str, S3Object], List[str]]:
+        # TODO: Support maxdepth and withdirs
         path = self._strip_protocol(path)
         if path in ["", "/"]:
             raise ValueError("Cannot traverse all files in S3.")
@@ -405,7 +416,7 @@ def find(self, path, maxdepth=None, withdirs=None, detail=False, **kwargs):
         else:
             return [f.name for f in files]
 
-    def exists(self, path, **kwargs):
+    def exists(self, path: str, **kwargs) -> bool:
         path = self._strip_protocol(path)
         if path in ["", "/"]:
             # The root always exists.
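
A hedged sketch of the newly typed traversal helpers, reusing the fs instance from the ls sketch above. find still ignores maxdepth and withdirs (see the TODO), and its return type mirrors ls: Dict[str, S3Object] with detail=True, otherwise List[str].

    found = fs.find("s3://example-bucket/prefix", detail=True)
    for name, obj in found.items():   # keys are object names, values S3Object
        print(name, obj.get("size"))  # assumption: size is exposed via the mapping

    if fs.exists("s3://example-bucket/prefix/data.csv"):
        print("object or prefix exists")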
@@ -436,19 +447,19 @@ def exists(self, path, **kwargs):
         else:
             return False
 
-    def rm_file(self, path):
+    def rm_file(self, path: str) -> None:
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def rmdir(self, path):
+    def rmdir(self, path: str) -> None:
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def _rm(self, path):
+    def _rm(self, path: str) -> None:
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def touch(self, path, truncate=True, **kwargs):
+    def touch(self, path: str, truncate: bool = True, **kwargs) -> Dict[str, Any]:
         bucket, key, version_id = self.parse_path(path)
         if version_id:
             raise ValueError("Cannot touch the file with the version specified.")
@@ -461,11 +472,13 @@ def touch(self, path, truncate=True, **kwargs):
         self.invalidate_cache(path)
         return object_.to_dict()
 
-    def cp_file(self, path1, path2, **kwargs):
+    def cp_file(self, path1: str, path2: str, **kwargs):
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def cat_file(self, path, start=None, end=None, **kwargs):
+    def cat_file(
+        self, path: str, start: Optional[int] = None, end: Optional[int] = None, **kwargs
+    ) -> bytes:
         bucket, key, version_id = self.parse_path(path)
         if start is not None and end is not None:
             ranges = (start, end)
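
touch now advertises Dict[str, Any]: per the body above it returns S3PutObject.to_dict(), a deep copy of the response object's attribute dict. A small sketch with a placeholder path, reusing the fs instance from earlier:

    result = fs.touch("s3://example-bucket/prefix/empty.txt")
    print(sorted(result))  # attribute names of the S3PutObject response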
@@ -474,40 +487,40 @@ def cat_file(self, path, start=None, end=None, **kwargs):
 
         return self._get_object(
             bucket=bucket,
-            key=key,
+            key=cast(str, key),
             ranges=ranges,
             version_id=version_id,
             **kwargs,
         )[1]
 
-    def pipe_file(self, path, value, **kwargs):
+    def pipe_file(self, path: str, value, **kwargs):
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def put_file(self, lpath, rpath, callback=_DEFAULT_CALLBACK, **kwargs):
+    def put_file(self, lpath: str, rpath: str, callback=_DEFAULT_CALLBACK, **kwargs):
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs):
+    def get_file(self, rpath: str, lpath: str, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs):
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def checksum(self, path):
+    def checksum(self, path: str):
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def sign(self, path, expiration=100, **kwargs):
+    def sign(self, path: str, expiration: int = 100, **kwargs):
         # TODO
         raise NotImplementedError  # pragma: no cover
 
-    def created(self, path):
+    def created(self, path: str) -> datetime:
         return self.modified(path)
 
-    def modified(self, path):
+    def modified(self, path: str) -> datetime:
         info = self.info(path)
-        return info.get("last_modified")
+        return cast(datetime, info.get("last_modified"))
 
-    def invalidate_cache(self, path=None):
+    def invalidate_cache(self, path: Optional[str] = None) -> None:
         if path is None:
             self.dircache.clear()
         else:
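
A worked example of cat_file's range convention: start and end form a half-open interval, which _get_object turns into an inclusive HTTP Range header (shown in the _get_object hunk below). Placeholder path again:

    first_kb = fs.cat_file("s3://example-bucket/prefix/data.csv", start=0, end=1024)
    # (0, 1024) becomes "Range: bytes=0-1023", i.e. at most 1024 bytes
    assert len(first_kb) <= 1024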
@@ -518,14 +531,14 @@
 
     def _open(
         self,
-        path,
-        mode="rb",
-        block_size=None,
-        cache_type=None,
-        autocommit=True,
-        cache_options=None,
+        path: str,
+        mode: str = "rb",
+        block_size: Optional[int] = None,
+        cache_type: Optional[str] = None,
+        autocommit: bool = True,
+        cache_options: Optional[Dict[Any, Any]] = None,
         **kwargs,
-    ):
+    ) -> S3File:
         if block_size is None:
             block_size = self.default_block_size
         if cache_type is None:
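
_open is the hook that fsspec's public open() delegates to; block_size and cache_type fall back to the filesystem-level defaults when left None, and the return type is now pinned to S3File. A short sketch with a placeholder path:

    with fs.open("s3://example-bucket/prefix/data.csv", mode="rb") as f:  # S3File
        header = f.read(128)  # buffered, ranged reads under the hood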
@@ -558,7 +571,7 @@ def _get_object(
             range_ = f"bytes={ranges[0]}-{ranges[1] - 1}"
             request.update({"Range": range_})
         else:
-            ranges = (0, None)
+            ranges = (0, 0)
             range_ = "bytes=0-"
         if version_id:
             request.update({"VersionId": version_id})
@@ -692,8 +705,8 @@ def __init__(
         autocommit: bool = True,
         cache_options: Optional[Dict[Any, Any]] = None,
         size: Optional[int] = None,
-        s3_additional_kwargs=None,
-    ):
+        s3_additional_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         self.max_workers = max_workers
         self._executor = ThreadPoolExecutor(max_workers=max_workers)
         self.s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {}
@@ -750,11 +763,11 @@ def __init__(
         self.multipart_upload: Optional[S3MultipartUpload] = None
         self.multipart_upload_parts: List[Future[S3MultipartUploadPart]] = []
 
-    def close(self):
+    def close(self) -> None:
         super().close()
         self._executor.shutdown()
 
-    def _initiate_upload(self):
+    def _initiate_upload(self) -> None:
         if self.tell() < self.blocksize:
             # Files smaller than block size in size cannot be multipart uploaded.
             return
@@ -776,7 +789,7 @@ def _initiate_upload(self):
             )
         )
 
-    def _upload_chunk(self, final=False):
+    def _upload_chunk(self, final: bool = False) -> bool:
         if self.tell() < self.blocksize:
             # Files smaller than block size in size cannot be multipart uploaded.
             if self.autocommit and final:
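
Both _initiate_upload and _upload_chunk guard on the same rule: anything below one block is written with a single put instead of a multipart upload. A tiny sketch of the rule (5 MiB is S3's documented minimum part size; the actual default block size is assumed to be at least that):

    def uses_multipart(bytes_written: int, blocksize: int) -> bool:
        # mirrors the `self.tell() < self.blocksize` guard above
        return bytes_written >= blocksize

    assert not uses_multipart(1024, 5 * 2**20)   # small file: single put on commit
    assert uses_multipart(6 * 2**20, 5 * 2**20)  # at least one full block: multipart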
@@ -805,7 +818,7 @@ def _upload_chunk(self, final=False):
             self.commit()
         return True
 
-    def commit(self):
+    def commit(self) -> None:
         if self.tell() == 0:
             if self.buffer is not None:
                 self.discard()
@@ -845,7 +858,7 @@ def commit(self):
 
         self.fs.invalidate_cache(self.path)
 
-    def discard(self):
+    def discard(self) -> None:
         if self.multipart_upload:
             for f in self.multipart_upload_parts:
                 f.cancel()
35 changes: 18 additions & 17 deletions pyathena/filesystem/s3_object.py
@@ -7,6 +7,7 @@
 from typing import Any, Dict, Iterator, MutableMapping, Optional
 
 _logger = logging.getLogger(__name__)  # type: ignore
+
 _API_FIELD_TO_S3_OBJECT_PROPERTY = {
     "ETag": "etag",
     "CacheControl": "cache_control",
@@ -36,26 +37,26 @@ class S3ObjectType:
 
 
 class S3StorageClass:
-    S3_STORAGE_CLASS_STANDARD = "STANDARD"
-    S3_STORAGE_CLASS_REDUCED_REDUNDANCY = "REDUCED_REDUNDANCY"
-    S3_STORAGE_CLASS_STANDARD_IA = "STANDARD_IA"
-    S3_STORAGE_CLASS_ONEZONE_IA = "ONEZONE_IA"
-    S3_STORAGE_CLASS_INTELLIGENT_TIERING = "INTELLIGENT_TIERING"
-    S3_STORAGE_CLASS_GLACIER = "GLACIER"
-    S3_STORAGE_CLASS_DEEP_ARCHIVE = "DEEP_ARCHIVE"
-    S3_STORAGE_CLASS_OUTPOSTS = "OUTPOSTS"
-    S3_STORAGE_CLASS_GLACIER_IR = "GLACIER_IR"
+    S3_STORAGE_CLASS_STANDARD: str = "STANDARD"
+    S3_STORAGE_CLASS_REDUCED_REDUNDANCY: str = "REDUCED_REDUNDANCY"
+    S3_STORAGE_CLASS_STANDARD_IA: str = "STANDARD_IA"
+    S3_STORAGE_CLASS_ONEZONE_IA: str = "ONEZONE_IA"
+    S3_STORAGE_CLASS_INTELLIGENT_TIERING: str = "INTELLIGENT_TIERING"
+    S3_STORAGE_CLASS_GLACIER: str = "GLACIER"
+    S3_STORAGE_CLASS_DEEP_ARCHIVE: str = "DEEP_ARCHIVE"
+    S3_STORAGE_CLASS_OUTPOSTS: str = "OUTPOSTS"
+    S3_STORAGE_CLASS_GLACIER_IR: str = "GLACIER_IR"
 
-    S3_STORAGE_CLASS_BUCKET = "BUCKET"
-    S3_STORAGE_CLASS_DIRECTORY = "DIRECTORY"
+    S3_STORAGE_CLASS_BUCKET: str = "BUCKET"
+    S3_STORAGE_CLASS_DIRECTORY: str = "DIRECTORY"
 
 
 class S3Object(MutableMapping[str, Any]):
     def __init__(
         self,
         init: Dict[str, Any],
         **kwargs,
-    ):
+    ) -> None:
         if init:
             super().update({_API_FIELD_TO_S3_OBJECT_PROPERTY.get(k, k): v for k, v in init.items()})
         if "Size" in init:
@@ -109,7 +110,7 @@ def to_api_repr(self) -> Dict[str, Any]:
 
 
 class S3PutObject:
-    def __init__(self, response: Dict[str, Any]):
+    def __init__(self, response: Dict[str, Any]) -> None:
         self._expiration: Optional[str] = response.get("Expiration")
         self._version_id: Optional[str] = response.get("VersionId")
         self._etag: Optional[str] = response.get("ETag")
@@ -181,12 +182,12 @@ def bucket_key_enabled(self) -> Optional[bool]:
     def request_charged(self) -> Optional[str]:
         return self._request_charged
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         return copy.deepcopy(self.__dict__)
 
 
 class S3MultipartUpload:
-    def __init__(self, response: Dict[str, Any]):
+    def __init__(self, response: Dict[str, Any]) -> None:
         self._abort_date = response.get("AbortDate")
         self._abort_rule_id = response.get("AbortRuleId")
         self._bucket = response.get("Bucket")
@@ -255,7 +256,7 @@ def checksum_algorithm(self) -> Optional[str]:
 
 
 class S3MultipartUploadPart:
-    def __init__(self, part_number: int, response: Dict[str, Any]):
+    def __init__(self, part_number: int, response: Dict[str, Any]) -> None:
         self._part_number = part_number
         self._copy_source_version_id: Optional[str] = response.get("CopySourceVersionId")
         copy_part_result = response.get("CopyPartResult")
@@ -348,7 +349,7 @@ def to_api_repr(self) -> Dict[str, Any]:
 
 
 class S3CompleteMultipartUpload:
-    def __init__(self, response: Dict[str, Any]):
+    def __init__(self, response: Dict[str, Any]) -> None:
         self._location: Optional[str] = response.get("Location")
         self._bucket: Optional[str] = response.get("Bucket")
         self._key: Optional[str] = response.get("Key")
