Skip to content

Commit

Permalink
Add test to validate that client is being reused
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermef committed May 1, 2022
1 parent 2f18b30 commit e4a0c42
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 14 deletions.
8 changes: 6 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
from thumbor.importer import Importer
from thumbor.testing import TestCase

from thumbor_aws.s3_client import S3Client
import thumbor_aws.s3_client


class BaseS3TestCase(TestCase):
test_images = {}

def setUp(self):
super().setUp()
thumbor_aws.s3_client.S3_CLIENT = None

@property
def bucket_name(self):
"""Name of the bucket to put test files in"""
Expand Down Expand Up @@ -106,7 +110,7 @@ def get_context(self) -> Context:

async def ensure_bucket(self):
"""Ensures the test bucket is created"""
s3client = S3Client(self.context)
s3client = thumbor_aws.s3_client.S3Client(self.context)
if self.context.config.THUMBOR_AWS_RUN_IN_COMPATIBILITY_MODE is True:
s3client.configuration["region_name"] = self.config.TC_AWS_REGION
s3client.configuration[
Expand Down
35 changes: 35 additions & 0 deletions tests/test_s3_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

# thumbor aws extensions
# https://github.com/thumbor/thumbor-aws

# Licensed under the MIT license:
# http://www.opensource.org/licenses/mit-license
# Copyright (c) 2021 Bernardo Heynemann heynemann@gmail.com


import pytest
from preggy import expect
from tornado.testing import gen_test

from tests import BaseS3TestCase

import thumbor_aws.s3_client


@pytest.mark.usefixtures("test_images")
class S3ClientTestCase(BaseS3TestCase):
@gen_test
async def test_should_reuse_the_same_client(self):
"""
Verifies that the S3 client module will
reuse the same AioBoto client.
"""
s3client = await thumbor_aws.s3_client.S3Client(
self.context
).get_client()
s3client2 = await thumbor_aws.s3_client.S3Client(
self.context
).get_client()
expect(s3client).to_equal(s3client2)
26 changes: 14 additions & 12 deletions thumbor_aws/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from thumbor.utils import logger

_default = object()
S3_CLIENT = None


class S3Client:
__client: AioBaseClient = None
context: Context = None
configuration: Dict[str, object] = None

Expand Down Expand Up @@ -88,16 +88,17 @@ def session(self) -> AioSession:

async def get_client(self) -> AioBaseClient:
"""Singleton client for S3"""
if self.__client is None:
global S3_CLIENT # pylint: disable=global-statement
if S3_CLIENT is None:
client = self.session.create_client(
"s3",
region_name=self.region_name,
aws_secret_access_key=self.secret_access_key,
aws_access_key_id=self.access_key_id,
endpoint_url=self.endpoint_url,
)
self.__client = await client.__aenter__()
return self.__client
S3_CLIENT = await client.__aenter__()
return S3_CLIENT

async def upload(
self,
Expand All @@ -109,6 +110,7 @@ async def upload(
"""Uploads a File to S3"""

client = await self.get_client()

response = None
try:
settings = dict(
Expand All @@ -127,7 +129,9 @@ async def upload(
raise RuntimeError(msg) # pylint: disable=raise-missing-from
status_code = self.get_status_code(response)
if status_code != 200:
msg = f"Unable to upload image to {path}: Status Code {status_code}"
msg = (
f"Unable to upload image to {path}: Status Code {status_code}"
)
logger.error(msg)
raise RuntimeError(msg)

Expand All @@ -138,9 +142,7 @@ async def upload(
"Location Headers was not found in response"
)
logger.warning(msg)
location = default_location.format(
bucket_name=self.bucket_name
)
location = default_location.format(bucket_name=self.bucket_name)

return f"{location.rstrip('/')}/{path.lstrip('/')}"

Expand All @@ -159,7 +161,9 @@ async def get_data(

status_code = self.get_status_code(response)
if status_code != 200:
msg = f"Unable to upload image to {path}: Status Code {status_code}"
msg = (
f"Unable to upload image to {path}: Status Code {status_code}"
)
logger.error(msg)
return status_code, msg, None

Expand All @@ -176,9 +180,7 @@ async def object_exists(self, filepath: str):

client = await self.get_client()
try:
await client.get_object_acl(
Bucket=self.bucket_name, Key=filepath
)
await client.get_object_acl(Bucket=self.bucket_name, Key=filepath)
return True
except client.exceptions.NoSuchKey:
return False
Expand Down

0 comments on commit e4a0c42

Please sign in to comment.