From e893f86d97832535fcfdd78bd678f8af276b2867 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Wed, 1 May 2024 00:15:56 +0900 Subject: [PATCH] WIP --- pyathena/filesystem/s3.py | 510 ++++++++++++++++++++++++++----- pyathena/filesystem/s3_object.py | 391 +++++++++++++++++++++++- 2 files changed, 804 insertions(+), 97 deletions(-) diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py index 25c14899..184d5a4c 100644 --- a/pyathena/filesystem/s3.py +++ b/pyathena/filesystem/s3.py @@ -4,20 +4,30 @@ import itertools import logging import re +from concurrent.futures import Future, as_completed from concurrent.futures.thread import ThreadPoolExecutor from copy import deepcopy from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, Union, cast import botocore.exceptions from boto3 import Session from botocore import UNSIGNED from botocore.client import BaseClient, Config from fsspec import AbstractFileSystem +from fsspec.callbacks import _DEFAULT_CALLBACK from fsspec.spec import AbstractBufferedFile import pyathena -from pyathena.filesystem.s3_object import S3Object, S3ObjectType, S3StorageClass +from pyathena.filesystem.s3_object import ( + S3CompleteMultipartUpload, + S3MultipartUpload, + S3MultipartUploadPart, + S3Object, + S3ObjectType, + S3PutObject, + S3StorageClass, +) from pyathena.util import RetryConfig, retry_api_call if TYPE_CHECKING: @@ -42,6 +52,7 @@ def __init__( default_block_size: Optional[int] = None, default_cache_type: Optional[str] = None, max_workers: int = (cpu_count() or 1) * 5, + s3_additional_kwargs=None, *args, **kwargs, ) -> None: @@ -62,6 +73,7 @@ def __init__( ) self.default_cache_type = default_cache_type if default_cache_type else "bytes" self.max_workers = max_workers + self.s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {} requester_pays = kwargs.pop("requester_pays", False) self.request_kwargs = {"RequestPayer": "requester"} if requester_pays else {} @@ -127,7 +139,7 @@ def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]: else: raise ValueError(f"Invalid S3 path format {path}.") - def _head_bucket(self, bucket, refresh: bool = False) -> Optional[Dict[Any, Any]]: + def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]: if bucket not in self.dircache or refresh: try: retry_api_call( @@ -141,28 +153,42 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[Dict[Any, Any] return None raise file = S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_BUCKET, + "ETag": None, + "VersionId": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, bucket=bucket, key=None, - size=0, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - storage_class=S3StorageClass.S3_STORAGE_CLASS_BUCKET, - etag=None, - ).to_dict() + version_id=None, + ) self.dircache[bucket] = file else: file = self.dircache[bucket] return file - def _head_object(self, path: str, refresh: bool = False) -> Optional[Dict[Any, Any]]: - bucket, key, _ = self.parse_path(path) + def _head_object( + self, path: str, version_id: Optional[str] = None, refresh: bool = False + ) -> Optional[S3Object]: + bucket, key, path_version_id = self.parse_path(path) + version_id = path_version_id if path_version_id else version_id if path not in self.dircache or refresh: try: + request = { + "Bucket": bucket, + 
"Key": key, + } + if version_id: + request.update({"VersionId": version_id}) response = retry_api_call( self._client.head_object, config=self._retry_config, logger=_logger, - Bucket=bucket, - Key=key, + **request, **self.request_kwargs, ) except botocore.exceptions.ClientError as e: @@ -170,21 +196,18 @@ def _head_object(self, path: str, refresh: bool = False) -> Optional[Dict[Any, A return None raise file = S3Object( + init=response, + type=S3ObjectType.S3_OBJECT_TYPE_FILE, bucket=bucket, key=key, - size=response["ContentLength"], - type=S3ObjectType.S3_OBJECT_TYPE_FILE, - storage_class=response.get( - "StorageClass", S3StorageClass.S3_STORAGE_CLASS_STANDARD - ), - etag=response["ETag"], - ).to_dict() + version_id=version_id, + ) self.dircache[path] = file else: file = self.dircache[path] return file - def _ls_buckets(self, refresh: bool = False) -> List[Dict[Any, Any]]: + def _ls_buckets(self, refresh: bool = False) -> List[S3Object]: if "" not in self.dircache or refresh: try: response = retry_api_call( @@ -198,13 +221,19 @@ def _ls_buckets(self, refresh: bool = False) -> List[Dict[Any, Any]]: raise buckets = [ S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_BUCKET, + "ETag": None, + "VersionId": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, bucket=b["Name"], key=None, - size=0, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - storage_class=S3StorageClass.S3_STORAGE_CLASS_BUCKET, - etag=None, - ).to_dict() + version_id=None, + ) for b in response["Buckets"] ] self.dircache[""] = buckets @@ -220,12 +249,12 @@ def _ls_dirs( next_token=None, max_keys: Optional[int] = None, refresh: bool = False, - ) -> List[Dict[Any, Any]]: - bucket, key, _ = self.parse_path(path) + ) -> List[S3Object]: + bucket, key, version_id = self.parse_path(path) if key: prefix = f"{key}/{prefix if prefix else ''}" if path not in self.dircache or refresh: - files: List[Dict[Any, Any]] = [] + files: List[S3Object] = [] while True: request: Dict[Any, Any] = { "Bucket": bucket, @@ -245,24 +274,27 @@ def _ls_dirs( ) files.extend( S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, + "ETag": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, bucket=bucket, key=c["Prefix"][:-1], - size=0, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - storage_class=S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, - etag=None, - ).to_dict() + version_id=version_id, + ) for c in response.get("CommonPrefixes", []) ) files.extend( S3Object( + init=c, + type=S3ObjectType.S3_OBJECT_TYPE_FILE, bucket=bucket, key=c["Key"], - size=c["Size"], - type=S3ObjectType.S3_OBJECT_TYPE_FILE, - storage_class=c["StorageClass"], - etag=c["ETag"], - ).to_dict() + ) for c in response.get("Contents", []) ) next_token = response.get("NextContinuationToken") @@ -284,27 +316,33 @@ def ls(self, path, detail=False, refresh=False, **kwargs): file = self._head_object(path, refresh=refresh) if file: files = [file] - return files if detail else [f["name"] for f in files] + return [f for f in files] if detail else [f.name for f in files] - def info(self, path, **kwargs): + def info(self, path, **kwargs) -> S3Object: refresh = kwargs.pop("refresh", False) path = self._strip_protocol(path) - bucket, key, _ = self.parse_path(path) + bucket, key, path_version_id = self.parse_path(path) + version_id = path_version_id if path_version_id else kwargs.pop("version_id", None) if path in 
["/", ""]: return S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, + "ETag": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, bucket=bucket, key=path, - size=0, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - storage_class=S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, - etag=None, - ).to_dict() + version_id=version_id, + ) if not refresh: - caches = self._ls_from_cache(path) + caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path) if caches is not None: if isinstance(caches, list): - cache = next((c for c in caches if c["name"] == path), None) - elif ("name" in caches) and (caches["name"] == path): + cache = next((c for c in caches if c.name == path), None) + elif caches.name == path: cache = caches else: cache = None @@ -313,15 +351,20 @@ def info(self, path, **kwargs): return cache else: return S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, + "ETag": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, bucket=bucket, key=path, - size=0, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - storage_class=S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, - etag=None, - ).to_dict() + version_id=version_id, + ) if key: - info = self._head_object(path, refresh=refresh) + info = self._head_object(path, refresh=refresh, version_id=version_id) if info: return info @@ -341,13 +384,18 @@ def info(self, path, **kwargs): or response.get("CommonPrefixes", []) ): return S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, + "ETag": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, bucket=bucket, key=path, - size=0, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - storage_class=S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, - etag=None, - ).to_dict() + version_id=version_id, + ) else: raise FileNotFoundError(path) @@ -365,9 +413,9 @@ def find(self, path, maxdepth=None, withdirs=None, detail=False, **kwargs): except FileNotFoundError: files = [] if detail: - return {f["name"]: f for f in files} + return {f.name: f for f in files} else: - return [f["name"] for f in files] + return [f.name for f in files] def exists(self, path, **kwargs): path = self._strip_protocol(path) @@ -379,8 +427,8 @@ def exists(self, path, **kwargs): try: if self._ls_from_cache(path): return True - file = self.info(path) - if file: + info = self.info(path) + if info: return True else: return False @@ -400,21 +448,75 @@ def exists(self, path, **kwargs): else: return False - def cp_file(self, path1, path2, **kwargs): + def rm_file(self, path): + # TODO + raise NotImplementedError # pragma: no cover + + def rmdir(self, path): + # TODO raise NotImplementedError # pragma: no cover def _rm(self, path): + # TODO raise NotImplementedError # pragma: no cover - def created(self, path): + def touch(self, path, truncate=True, **kwargs): + bucket, key, version_id = self.parse_path(path) + if version_id: + raise ValueError("Cannot touch the file with the version specified.") + if not truncate and self.exists(path): + raise ValueError("Cannot touch the existing file without specifying truncate.") + if not key: + raise ValueError("Cannot touch the bucket.") + + object = self._put_object(bucket=bucket, key=key, body=None, **kwargs) + self.invalidate_cache(self._parent(path)) + return object.to_dict() + + def cp_file(self, path1, path2, 
**kwargs): + # TODO raise NotImplementedError # pragma: no cover - def modified(self, path): + def cat_file(self, path, start=None, end=None, **kwargs): + # TODO + raise NotImplementedError # pragma: no cover + + def pipe_file(self, path, value, **kwargs): + # TODO + raise NotImplementedError # pragma: no cover + + def put_file(self, lpath, rpath, callback=_DEFAULT_CALLBACK, **kwargs): + # TODO + raise NotImplementedError # pragma: no cover + + def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs): + # TODO + raise NotImplementedError # pragma: no cover + + def checksum(self, path): + # TODO raise NotImplementedError # pragma: no cover def sign(self, path, expiration=100, **kwargs): + # TODO raise NotImplementedError # pragma: no cover + def created(self, path): + return self.modified(path) + + def modified(self, path): + info = self.info(path) + return info.get("last_modified") + + def invalidate_cache(self, path=None): + if path is None: + self.dircache.clear() + else: + path = self._strip_protocol(path) + while path: + self.dircache.pop(path, None) + path = self._parent(path) + def _open( self, path, @@ -440,6 +542,7 @@ def _open( cache_type=cache_type, autocommit=autocommit, cache_options=cache_options, + s3_additional_kwargs=self.s3_additional_kwargs, **kwargs, ) @@ -449,13 +552,12 @@ def _get_object( key: str, ranges: Tuple[int, int], version_id: Optional[str] = None, - kwargs: Optional[Dict[Any, Any]] = None, + **kwargs, ) -> Tuple[int, bytes]: range_ = f"bytes={ranges[0]}-{ranges[1] - 1}" request = {"Bucket": bucket, "Key": key, "Range": range_} if version_id: request.update({"versionId": version_id}) - kwargs = kwargs if kwargs else {} _logger.debug(f"Get object: s3://{bucket}/{key}?versionId={version_id} {range_}") response = retry_api_call( @@ -468,6 +570,124 @@ def _get_object( ) return ranges[0], cast(bytes, response["Body"].read()) + def _put_object(self, bucket: str, key: str, body: Optional[bytes], **kwargs) -> S3PutObject: + request: Dict[str, Any] = {"Bucket": bucket, "Key": key} + if body: + request.update({"Body": body}) + + _logger.debug(f"Put object: s3://{bucket}/{key}") + response = retry_api_call( + self._client.put_object, + config=self._retry_config, + logger=_logger, + **request, + **kwargs, + **self.request_kwargs, + ) + return S3PutObject(response) + + def _create_multipart_upload(self, bucket: str, key: str, **kwargs) -> S3MultipartUpload: + request = { + "Bucket": bucket, + "Key": key, + } + + _logger.debug(f"Create multipart upload to s3://{bucket}/{key}.") + response = retry_api_call( + self._client.create_multipart_upload, + config=self._retry_config, + logger=_logger, + **request, + **kwargs, + **self.request_kwargs, + ) + return S3MultipartUpload(response) + + def _upload_part_copy( + self, + bucket: str, + key: str, + copy_source: str, + upload_id: str, + part_number: int, + **kwargs, + ) -> S3MultipartUploadPart: + request = { + "Bucket": bucket, + "Key": key, + "CopySource": copy_source, + "UploadId": upload_id, + "PartNumber": part_number, + } + + _logger.debug( + f"Upload part copy from {copy_source} to s3://{bucket}/{key} as part {part_number}." 
+ ) + response = retry_api_call( + self._client.upload_part_copy, + config=self._retry_config, + logger=_logger, + **request, + **kwargs, + **self.request_kwargs, + ) + return S3MultipartUploadPart(part_number, response) + + def _upload_part( + self, + bucket: str, + key: str, + upload_id: str, + part_number: int, + body: bytes, + **kwargs, + ) -> S3MultipartUploadPart: + request = { + "Bucket": bucket, + "Key": key, + "UploadId": upload_id, + "PartNumber": part_number, + "Body": body, + } + + _logger.debug(f"Upload part of {upload_id} to s3://{bucket}/{key} as part {part_number}.") + response = retry_api_call( + self._client.upload_part, + config=self._retry_config, + logger=_logger, + **request, + **kwargs, + **self.request_kwargs, + ) + return S3MultipartUploadPart(part_number, response) + + def _complete_multipart_upload( + self, bucket: str, key: str, upload_id: str, parts: List[Dict[str, Any]], **kwargs + ) -> S3CompleteMultipartUpload: + request = { + "Bucket": bucket, + "Key": key, + "UploadId": upload_id, + "MultipartUpload": {"Parts": parts}, + } + + _logger.debug(f"Complete multipart upload {upload_id} to s3://{bucket}/{key}.") + response = retry_api_call( + self._client.complete_multipart_upload, + config=self._retry_config, + logger=_logger, + **request, + **kwargs, + **self.request_kwargs, + ) + return S3CompleteMultipartUpload(response) + + def _call(self, method: str, **kwargs) -> Dict[str, Any]: + response = retry_api_call( + getattr(self._client, method), config=self._retry_config, logger=_logger, **kwargs + ) + return cast(Dict[str, Any], response) + class S3File(AbstractBufferedFile): def __init__( @@ -482,18 +702,23 @@ def __init__( autocommit: bool = True, cache_options: Optional[Dict[Any, Any]] = None, size: Optional[int] = None, + s3_additional_kwargs=None, ): + self.max_workers = max_workers + self.worker_block_size = block_size + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self.s3_additional_kwargs = s3_additional_kwargs if s3_additional_kwargs else {} + super().__init__( fs=fs, path=path, mode=mode, - block_size=block_size * max_workers, + block_size=block_size, autocommit=autocommit, cache_type=cache_type, cache_options=cache_options, size=size, ) - self.fs = fs bucket, key, path_version_id = S3FileSystem.parse_path(path) self.bucket = bucket if not key: @@ -510,20 +735,133 @@ def __init__( self.version_id = path_version_id else: self.version_id = version_id - self.max_workers = max_workers - self.worker_block_size = block_size - self._executor = ThreadPoolExecutor(max_workers=max_workers) - if "r" in mode and "etag" in self.details: - self.request_kwargs = {"IfMatch": self.details["etag"]} + + self.append_block = False + if "r" in mode: + info = self.fs.info(self.path, version_id=self.version_id) + if etag := info.get("etag"): + self.s3_additional_kwargs.update({"IfMatch": etag}) + self._details = info + elif "a" in mode and self.fs.exists(path): + self.append_block = True + info = self.fs.info(self.path, version_id=self.version_id) + self.loc = info.get("size", 0) + self.s3_additional_kwargs.update(info.to_api_repr()) + self._details = info + else: + self._details = {} + + self.multipart_upload: Optional[S3MultipartUpload] = None + self.multipart_upload_parts: List[Future[S3MultipartUploadPart]] = [] def close(self): super().close() self._executor.shutdown() def _initiate_upload(self): - raise NotImplementedError # pragma: no cover + if self.autocommit and not self.append_block and self.tell() < self.blocksize: + return + + 
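+        # Beyond the early return above, the write is a real multipart upload;
+        # for appends, the existing object is copied in as part 1 via
+        # _upload_part_copy below, so new data can be appended as later parts.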
self.multipart_upload = self.fs._create_multipart_upload( + bucket=self.bucket, + key=self.key, + **self.s3_additional_kwargs, + ) + if self.append_block: + self.multipart_upload_parts.append( + self._executor.submit( + self.fs._upload_part_copy, + bucket=self.bucket, + key=self.key, + upload_id=cast(str, cast(S3MultipartUpload, self.multipart_upload).upload_id), + part_number=1, + **self.s3_additional_kwargs, + ) + ) + + def _upload_chunk(self, final=False): + if not self.append_block and self.tell() < self.blocksize: + if self.autocommit and final: + self.commit() + return True + + if not self.multipart_upload: + raise RuntimeError("Multipart upload is not initialized.") + if self.autocommit and final: + self.commit() + else: + part_number = len(self.multipart_upload_parts) + self.buffer.seek(0) + while data := self.buffer.read(self.blocksize): + part_number += 1 + self.multipart_upload_parts.append( + self._executor.submit( + self.fs._upload_part, + bucket=self.bucket, + key=self.key, + upload_id=cast(str, self.multipart_upload.upload_id), + part_number=part_number, + body=data, + **self.s3_additional_kwargs, + ) + ) + return True + + def commit(self): + if self.tell() == 0: + if self.buffer is not None: + self.discard() + self.fs.touch(self.path, **self.s3_additional_kwargs) + elif not self.multipart_upload_parts: + if self.buffer is not None: + self.buffer.seek(0) + data = self.buffer.read() + self.fs._put_object( + bucket=self.bucket, + key=self.key, + body=data, + **self.s3_additional_kwargs, + ) + else: + if not self.multipart_upload: + raise RuntimeError("Multipart upload is not initialized.") + + parts: List[Dict[str, Any]] = [] + for f in as_completed(self.multipart_upload_parts): + result = f.result() + part = { + "ETag": result.etag, + "PartNumber": result.part_number, + } + if result.checksum_sha256: + part.update({"ChecksumSHA256": result.checksum_sha256}) + parts.append(part) + parts.sort(key=lambda x: x["PartNumber"]) + self.fs._complete_multipart_upload( + bucket=self.bucket, + key=self.key, + upload_id=cast(str, self.multipart_upload.upload_id), + parts=parts, + ) + + self.fs.invalidate_cache(self.path) + + def discard(self): + if self.multipart_upload: + for f in self.multipart_upload_parts: + f.cancel() + self.fs._call( + "abort_multipart_upload", + Bucket=self.bucket, + Key=self.key, + UploadId=self.multipart_upload.upload_id, + **self.s3_additional_kwargs, + ) + + self.multipart_upload = None + self.multipart_upload_parts = [] - def _fetch_range(self, start, end): + def _fetch_range(self, start: int, end: int) -> bytes: ranges = self._get_ranges( start, end, max_workers=self.max_workers, worker_block_size=self.worker_block_size ) @@ -531,18 +869,28 @@ def _fetch_range(self, start, end): object_ = self._merge_objects( list( self._executor.map( - self.fs._get_object, + lambda bucket, key, ranges, version_id, kwargs: self.fs._get_object( + bucket=bucket, + key=key, + ranges=ranges, + version_id=version_id, + **kwargs, + ), itertools.repeat(self.bucket), itertools.repeat(self.key), ranges, itertools.repeat(self.version_id), - itertools.repeat(self.request_kwargs), + itertools.repeat(self.s3_additional_kwargs), ) ) ) else: object_ = self.fs._get_object( - self.bucket, self.key, ranges[0], self.version_id, self.request_kwargs + self.bucket, + self.key, + ranges[0], + self.version_id, + **self.s3_additional_kwargs, )[1] return object_ diff --git a/pyathena/filesystem/s3_object.py b/pyathena/filesystem/s3_object.py index 524b32b4..910d94bd 100644 --- 
a/pyathena/filesystem/s3_object.py +++ b/pyathena/filesystem/s3_object.py @@ -1,11 +1,32 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import copy import logging -from dataclasses import dataclass -from typing import Any, Dict, Optional +from datetime import datetime +from typing import Any, Dict, Iterator, MutableMapping, Optional _logger = logging.getLogger(__name__) # type: ignore +_API_FIELD_TO_S3_OBJECT_PROPERTY = { + "ETag": "etag", + "CacheControl": "cache_control", + "ContentDisposition": "content_disposition", + "ContentEncoding": "content_encoding", + "ContentLanguage": "content_language", + "ContentLength": "content_length", + "ContentType": "content_type", + "Expires": "expires", + "WebsiteRedirectLocation": "website_redirect_location", + "ServerSideEncryption": "server_side_encryption", + "SSECustomerAlgorithm": "sse_customer_algorithm", + "SSEKMSKeyId": "sse_kms_key_id", + "BucketKeyEnabled": "bucket_key_enabled", + "StorageClass": "storage_class", + "ObjectLockMode": "object_lock_mode", + "ObjectLockRetainUntilDate": "object_lock_retain_until_date", + "ObjectLockLegalHoldStatus": "object_lock_legal_hold_status", + "Metadata": "metadata", +} class S3ObjectType: @@ -28,20 +49,358 @@ class S3StorageClass: S3_STORAGE_CLASS_DIRECTORY = "DIRECTORY" -@dataclass -class S3Object: - bucket: str - key: Optional[str] - size: int - type: str - storage_class: str - etag: Optional[str] - - def __post_init__(self) -> None: - if self.key is None: - self.name = self.bucket +class S3Object(MutableMapping[str, Any]): + def __init__( + self, + init: Dict[str, Any], + **kwargs, + ): + if init: + super().update({_API_FIELD_TO_S3_OBJECT_PROPERTY.get(k, k): v for k, v in init.items()}) + if "Size" in init: + self.content_length = init["Size"] + self.size = init["Size"] + if "ContentLength" in init: + self.size = init["ContentLength"] + super().update({_API_FIELD_TO_S3_OBJECT_PROPERTY.get(k, k): v for k, v in kwargs.items()}) + if self.get("key") is None: + self.name = self.get("bucket") else: - self.name = f"{self.bucket}/{self.key}" + self.name = f"{self.get('bucket')}/{self.get('key')}" + + def get(self, key: str, default: Any = None) -> Any: + return super().get(key, default) + + def __getitem__(self, item: str) -> Any: + return self.__dict__.get(item) + + def __getattr__(self, item: str): + return self.get(item) + + def __setitem__(self, key: str, value: Any) -> None: + self.__dict__[key] = value + + def __setattr__(self, attr: str, value: Any) -> None: + self[attr] = value + + def __delitem__(self, key: str) -> None: + del self.__dict__[key] + + def __iter__(self) -> Iterator[str]: + return iter(self.__dict__.keys()) + + def __len__(self) -> int: + return len(self.__dict__) def to_dict(self) -> Dict[str, Any]: - return self.__dict__ + return copy.deepcopy(self.__dict__) + + def to_api_repr(self) -> Dict[str, Any]: + fields = {} + for k, v in _API_FIELD_TO_S3_OBJECT_PROPERTY.items(): + field = self.get(v) + if field is not None: + fields[k] = field + return fields + + +class S3PutObject: + def __init__(self, response: Dict[str, Any]): + self._expiration: Optional[str] = response.get("Expiration") + self._version_id: Optional[str] = response.get("VersionId") + self._etag: Optional[str] = response.get("ETag") + self._checksum_crc32: Optional[str] = response.get("ChecksumCRC32") + self._checksum_crc32c: Optional[str] = response.get("ChecksumCRC32C") + self._checksum_sha1: Optional[str] = response.get("ChecksumSHA1") + self._checksum_sha256: Optional[str] = 
response.get("ChecksumSHA256") + self._server_side_encryption = response.get("ServerSideEncryption") + self._sse_customer_algorithm = response.get("SSECustomerAlgorithm") + self._sse_customer_key_md5 = response.get("SSECustomerKeyMD5") + self._sse_kms_key_id = response.get("SSEKMSKeyId") + self._sse_kms_encryption_context = response.get("SSEKMSEncryptionContext") + self._bucket_key_enabled = response.get("BucketKeyEnabled") + self._request_charged = response.get("RequestCharged") + self._checksum_algorithm = response.get("ChecksumAlgorithm") + + @property + def expiration(self) -> Optional[str]: + return self._expiration + + @property + def version_id(self) -> Optional[str]: + return self._version_id + + @property + def etag(self) -> Optional[str]: + return self._etag + + @property + def checksum_crc32(self) -> Optional[str]: + return self._checksum_crc32 + + @property + def checksum_crc32c(self) -> Optional[str]: + return self._checksum_crc32c + + @property + def checksum_sha1(self) -> Optional[str]: + return self._checksum_sha1 + + @property + def checksum_sha256(self) -> Optional[str]: + return self._checksum_sha256 + + @property + def server_side_encryption(self) -> Optional[str]: + return self._server_side_encryption + + @property + def sse_customer_algorithm(self) -> Optional[str]: + return self._sse_customer_algorithm + + @property + def sse_customer_key_md5(self) -> Optional[str]: + return self._sse_customer_key_md5 + + @property + def sse_kms_key_id(self) -> Optional[str]: + return self._sse_kms_key_id + + @property + def sse_kms_encryption_context(self) -> Optional[str]: + return self._sse_kms_encryption_context + + @property + def bucket_key_enabled(self) -> Optional[bool]: + return self._bucket_key_enabled + + @property + def request_charged(self) -> Optional[str]: + return self._request_charged + + def to_dict(self): + return copy.deepcopy(self.__dict__) + + +class S3MultipartUpload: + def __init__(self, response: Dict[str, Any]): + self._abort_date = response.get("AbortDate") + self._abort_rule_id = response.get("AbortRuleId") + self._bucket = response.get("Bucket") + self._key = response.get("Key") + self._upload_id = response.get("UploadId") + self._server_side_encryption = response.get("ServerSideEncryption") + self._sse_customer_algorithm = response.get("SSECustomerAlgorithm") + self._sse_customer_key_md5 = response.get("SSECustomerKeyMD5") + self._sse_kms_key_id = response.get("SSEKMSKeyId") + self._sse_kms_encryption_context = response.get("SSEKMSEncryptionContext") + self._bucket_key_enabled = response.get("BucketKeyEnabled") + self._request_charged = response.get("RequestCharged") + self._checksum_algorithm = response.get("ChecksumAlgorithm") + + @property + def abort_date(self) -> Optional[datetime]: + return self._abort_date + + @property + def abort_rule_id(self) -> Optional[str]: + return self._abort_rule_id + + @property + def bucket(self) -> Optional[str]: + return self._bucket + + @property + def key(self) -> Optional[str]: + return self._key + + @property + def upload_id(self) -> Optional[str]: + return self._upload_id + + @property + def server_side_encryption(self) -> Optional[str]: + return self._server_side_encryption + + @property + def sse_customer_algorithm(self) -> Optional[str]: + return self._sse_customer_algorithm + + @property + def sse_customer_key_md5(self) -> Optional[str]: + return self._sse_customer_key_md5 + + @property + def sse_kms_key_id(self) -> Optional[str]: + return self._sse_kms_key_id + + @property + def 
sse_kms_encryption_context(self) -> Optional[str]: + return self._sse_kms_encryption_context + + @property + def bucket_key_enabled(self) -> Optional[bool]: + return self._bucket_key_enabled + + @property + def request_charged(self) -> Optional[str]: + return self._request_charged + + @property + def checksum_algorithm(self) -> Optional[str]: + return self._checksum_algorithm + + +class S3MultipartUploadPart: + def __init__(self, part_number: int, response: Dict[str, Any]): + self._part_number = part_number + self._copy_source_version_id: Optional[str] = response.get("CopySourceVersionId") + copy_part_result = response.get("CopyPartResult") + if copy_part_result: + self._last_modified: Optional[datetime] = copy_part_result.get("LastModified") + self._etag: Optional[str] = copy_part_result.get("ETag") + self._checksum_crc32: Optional[str] = copy_part_result.get("ChecksumCRC32") + self._checksum_crc32c: Optional[str] = copy_part_result.get("ChecksumCRC32C") + self._checksum_sha1: Optional[str] = copy_part_result.get("ChecksumSHA1") + self._checksum_sha256: Optional[str] = copy_part_result.get("ChecksumSHA256") + else: + self._last_modified = None + self._etag = response.get("ETag") + self._checksum_crc32 = response.get("ChecksumCRC32") + self._checksum_crc32c = response.get("ChecksumCRC32C") + self._checksum_sha1 = response.get("ChecksumSHA1") + self._checksum_sha256 = response.get("ChecksumSHA256") + self._server_side_encryption: Optional[str] = response.get("ServerSideEncryption") + self._sse_customer_algorithm: Optional[str] = response.get("SSECustomerAlgorithm") + self._sse_customer_key_md5: Optional[str] = response.get("SSECustomerKeyMD5") + self._sse_kms_key_id: Optional[str] = response.get("SSEKMSKeyId") + self._bucket_key_enabled: Optional[bool] = response.get("BucketKeyEnabled") + self._request_charged: Optional[str] = response.get("RequestCharged") + + @property + def part_number(self) -> int: + return self._part_number + + @property + def copy_source_version_id(self) -> Optional[str]: + return self._copy_source_version_id + + @property + def last_modified(self) -> Optional[datetime]: + return self._last_modified + + @property + def etag(self) -> Optional[str]: + return self._etag + + @property + def checksum_crc32(self) -> Optional[str]: + return self._checksum_crc32 + + @property + def checksum_crc32c(self) -> Optional[str]: + return self._checksum_crc32c + + @property + def checksum_sha1(self) -> Optional[str]: + return self._checksum_sha1 + + @property + def checksum_sha256(self) -> Optional[str]: + return self._checksum_sha256 + + @property + def server_side_encryption(self) -> Optional[str]: + return self._server_side_encryption + + @property + def sse_customer_algorithm(self) -> Optional[str]: + return self._sse_customer_algorithm + + @property + def sse_customer_key_md5(self) -> Optional[str]: + return self._sse_customer_key_md5 + + @property + def sse_kms_key_id(self) -> Optional[str]: + return self._sse_kms_key_id + + @property + def bucket_key_enabled(self) -> Optional[bool]: + return self._bucket_key_enabled + + @property + def request_charged(self) -> Optional[str]: + return self._request_charged + + def to_api_repr(self) -> Dict[str, Any]: + return { + "ETag": self.etag, + "ChecksumCRC32": self.checksum_crc32, + "ChecksumCRC32C": self.checksum_crc32c, + "ChecksumSHA1": self.checksum_sha1, + "ChecksumSHA256": self.checksum_sha256, + "PartNumber": self.part_number, + } + + +class S3CompleteMultipartUpload: + def __init__(self, response: Dict[str, Any]): + 
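+        # Thin, read-only wrapper over the CompleteMultipartUpload API response;
+        # fields absent from the response default to None.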
self._expiration: Optional[str] = response.get("Expiration") + self._version_id: Optional[str] = response.get("VersionId") + self._etag: Optional[str] = response.get("ETag") + self._checksum_crc32: Optional[str] = response.get("ChecksumCRC32") + self._checksum_crc32c: Optional[str] = response.get("ChecksumCRC32C") + self._checksum_sha1: Optional[str] = response.get("ChecksumSHA1") + self._checksum_sha256: Optional[str] = response.get("ChecksumSHA256") + self._server_side_encryption = response.get("ServerSideEncryption") + self._sse_kms_key_id = response.get("SSEKMSKeyId") + self._bucket_key_enabled = response.get("BucketKeyEnabled") + self._request_charged = response.get("RequestCharged") + + @property + def expiration(self) -> Optional[str]: + return self._expiration + + @property + def version_id(self) -> Optional[str]: + return self._version_id + + @property + def etag(self) -> Optional[str]: + return self._etag + + @property + def checksum_crc32(self) -> Optional[str]: + return self._checksum_crc32 + + @property + def checksum_crc32c(self) -> Optional[str]: + return self._checksum_crc32c + + @property + def checksum_sha1(self) -> Optional[str]: + return self._checksum_sha1 + + @property + def checksum_sha256(self) -> Optional[str]: + return self._checksum_sha256 + + @property + def server_side_encryption(self) -> Optional[str]: + return self._server_side_encryption + + @property + def sse_kms_key_id(self) -> Optional[str]: + return self._sse_kms_key_id + + @property + def bucket_key_enabled(self) -> Optional[bool]: + return self._bucket_key_enabled + + @property + def request_charged(self) -> Optional[str]: + return self._request_charged + + def to_dict(self): + return copy.deepcopy(self.__dict__)
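
A minimal usage sketch of the paths this patch adds, assuming it is applied on
top of pyathena and that valid AWS credentials are configured; the bucket and
key names below are placeholders:

    from pyathena.filesystem.s3 import S3FileSystem

    fs = S3FileSystem()

    # Writes below the block size stay buffered and are flushed with a single
    # PutObject in commit(); larger writes are promoted to a multipart upload
    # whose parts are submitted to the ThreadPoolExecutor as the buffer fills.
    with fs.open("s3://example-bucket/data/example.csv", "wb") as f:
        f.write(b"col1,col2\n1,2\n")

    # Reads are served by ranged GETs; requests wider than worker_block_size
    # are split across workers and reassembled by S3File._merge_objects.
    with fs.open("s3://example-bucket/data/example.csv", "rb") as f:
        data = f.read()

    # touch() issues an empty PutObject and invalidates the parent dircache entry.
    fs.touch("s3://example-bucket/data/_SUCCESS")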