Skip to content

Commit

Permalink
reuse existing client fixes #15
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermef committed Apr 26, 2022
1 parent 74753f3 commit a7981a2
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 97 deletions.
16 changes: 8 additions & 8 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ async def ensure_bucket(self):
location = {
"LocationConstraint": self.region_name,
}
async with s3client.get_client() as client:
try:
await client.create_bucket(
Bucket=self.bucket_name,
CreateBucketConfiguration=location,
)
except client.exceptions.BucketAlreadyOwnedByYou:
pass
client = await s3client.get_client()
try:
await client.create_bucket(
Bucket=self.bucket_name,
CreateBucketConfiguration=location,
)
except client.exceptions.BucketAlreadyOwnedByYou:
pass
159 changes: 80 additions & 79 deletions thumbor_aws/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class S3Client:
__session: AioSession = None
__client: AioBaseClient = None
context: Context = None
configuration: Dict[str, object] = None

Expand Down Expand Up @@ -83,20 +83,21 @@ def file_acl(self) -> str:

@property
def session(self) -> AioSession:
"""Singleton Session used for connecting with AWS"""
if self.__session is None:
self.__session = get_session()
return self.__session

def get_client(self) -> AioBaseClient:
"""Gets a connected client to use for S3"""
return self.session.create_client(
"s3",
region_name=self.region_name,
aws_secret_access_key=self.secret_access_key,
aws_access_key_id=self.access_key_id,
endpoint_url=self.endpoint_url,
)
"""Session used for connecting with AWS"""
return get_session()

async def get_client(self) -> AioBaseClient:
"""Singleton client for S3"""
if self.__client is None:
client = self.session.create_client(
"s3",
region_name=self.region_name,
aws_secret_access_key=self.secret_access_key,
aws_access_key_id=self.access_key_id,
endpoint_url=self.endpoint_url,
)
self.__client = await client.__aenter__()
return self.__client

async def upload(
self,
Expand All @@ -107,88 +108,88 @@ async def upload(
) -> str:
"""Uploads a File to S3"""

async with self.get_client() as client:
response = None
try:
settings = dict(
Bucket=self.bucket_name,
Key=path,
Body=data,
ContentType=content_type,
)
if self.file_acl is not None:
settings["ACL"] = self.file_acl

response = await client.put_object(**settings)
except Exception as error:
msg = f"Unable to upload image to {path}: {error} ({type(error)})"
logger.error(msg)
raise RuntimeError(msg) # pylint: disable=raise-missing-from
status_code = self.get_status_code(response)
if status_code != 200:
msg = f"Unable to upload image to {path}: Status Code {status_code}"
logger.error(msg)
raise RuntimeError(msg)

location = self.get_location(response)
if location is None:
msg = (
f"Unable to process response from AWS to {path}: "
"Location Headers was not found in response"
)
logger.warning(msg)
location = default_location.format(
bucket_name=self.bucket_name
)

return f"{location.rstrip('/')}/{path.lstrip('/')}"
client = await self.get_client()
response = None
try:
settings = dict(
Bucket=self.bucket_name,
Key=path,
Body=data,
ContentType=content_type,
)
if self.file_acl is not None:
settings["ACL"] = self.file_acl

response = await client.put_object(**settings)
except Exception as error:
msg = f"Unable to upload image to {path}: {error} ({type(error)})"
logger.error(msg)
raise RuntimeError(msg) # pylint: disable=raise-missing-from
status_code = self.get_status_code(response)
if status_code != 200:
msg = f"Unable to upload image to {path}: Status Code {status_code}"
logger.error(msg)
raise RuntimeError(msg)

location = self.get_location(response)
if location is None:
msg = (
f"Unable to process response from AWS to {path}: "
"Location Headers was not found in response"
)
logger.warning(msg)
location = default_location.format(
bucket_name=self.bucket_name
)

return f"{location.rstrip('/')}/{path.lstrip('/')}"

async def get_data(
self, path: str, expiration: int = _default
) -> (int, bytes, Optional[datetime.datetime]):
"""Gets an object's data from S3"""

async with self.get_client() as client:
try:
response = await client.get_object(
Bucket=self.bucket_name, Key=path
)
except client.exceptions.NoSuchKey:
return 404, b"", None
client = await self.get_client()
try:
response = await client.get_object(
Bucket=self.bucket_name, Key=path
)
except client.exceptions.NoSuchKey:
return 404, b"", None

status_code = self.get_status_code(response)
if status_code != 200:
msg = f"Unable to upload image to {path}: Status Code {status_code}"
logger.error(msg)
return status_code, msg, None
status_code = self.get_status_code(response)
if status_code != 200:
msg = f"Unable to upload image to {path}: Status Code {status_code}"
logger.error(msg)
return status_code, msg, None

last_modified = response["LastModified"]
if self._is_expired(last_modified, expiration):
return 410, b"", last_modified
last_modified = response["LastModified"]
if self._is_expired(last_modified, expiration):
return 410, b"", last_modified

body = await self.get_body(response)
body = await self.get_body(response)

return status_code, body, last_modified
return status_code, body, last_modified

async def object_exists(self, filepath: str):
"""Detects whether an object exists in S3"""

async with self.get_client() as client:
try:
await client.get_object_acl(
Bucket=self.bucket_name, Key=filepath
)
return True
except client.exceptions.NoSuchKey:
return False
client = await self.get_client()
try:
await client.get_object_acl(
Bucket=self.bucket_name, Key=filepath
)
return True
except client.exceptions.NoSuchKey:
return False

async def get_object_acl(self, filepath: str):
"""Gets an object's metadata"""

async with self.get_client() as client:
return await client.get_object_acl(
Bucket=self.bucket_name, Key=filepath
)
client = await self.get_client()
return await client.get_object_acl(
Bucket=self.bucket_name, Key=filepath
)

def get_status_code(self, response: Mapping[str, Any]) -> int:
"""Gets the status code from an AWS response object"""
Expand Down
20 changes: 10 additions & 10 deletions thumbor_aws/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,17 @@ async def remove(self, path: str):
if not exists:
return

async with self.get_client() as client:
normalized_path = self.normalize_path(path)
response = await client.delete_object(
Bucket=self.bucket_name,
Key=normalized_path,
client = await self.get_client()
normalized_path = self.normalize_path(path)
response = await client.delete_object(
Bucket=self.bucket_name,
Key=normalized_path,
)
status = self.get_status_code(response)
if status >= 300:
raise RuntimeError(
f"Failed to remove {normalized_path}: Status {status}"
)
status = self.get_status_code(response)
if status >= 300:
raise RuntimeError(
f"Failed to remove {normalized_path}: Status {status}"
)

def normalize_path(self, path: str) -> str:
"""Returns the path used for storage"""
Expand Down

0 comments on commit a7981a2

Please sign in to comment.