Skip to content

Commit

Permalink
tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
jschneier committed Sep 4, 2023
1 parent 79c25ad commit 1f45042
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 50 deletions.
2 changes: 1 addition & 1 deletion storages/backends/s3boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from storages.base import BaseStorage
from storages.compress import CompressedFileMixin
from storages.compress import CompressStorageMixin
from storages.utils import ReadBytesWrapper
from storages.utils import check_location
from storages.utils import clean_name
from storages.utils import get_available_overwrite_name
from storages.utils import is_seekable
from storages.utils import ReadBytesWrapper
from storages.utils import safe_join
from storages.utils import setting
from storages.utils import to_bytes
Expand Down
17 changes: 7 additions & 10 deletions storages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import SuspiciousFileOperation
from django.core.files.utils import FileProxyMixin
from django.utils.encoding import force_bytes


Expand Down Expand Up @@ -127,7 +128,7 @@ def is_seekable(file_object):
return not hasattr(file_object, 'seekable') or file_object.seekable()


class ReadBytesWrapper:
class ReadBytesWrapper(FileProxyMixin):
"""
A wrapper for a file-like object, that makes read() always returns bytes.
"""
Expand All @@ -138,20 +139,16 @@ def __init__(self, file, encoding=None):
If not provided will default to file.encoding, of if that's not available,
to utf-8.
"""
self._file = file
self.encoding = (
self.file = file
self._encoding = (
encoding
or getattr(file, "encoding", None)
or "utf-8"
)

def read(self, *args, **kwargs):
content = self._file.read(*args, **kwargs)
content = self.file.read(*args, **kwargs)

if not isinstance(content, bytes):
return content.encode(self.encoding)
else:
return content

def seek(self, *args, **kwargs):
return self._file.seek(*args, **kwargs)
content = content.encode(self._encoding)
return content
4 changes: 1 addition & 3 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,4 @@
USE_TZ = True

# the following test settings are required for moto to work.
AWS_STORAGE_BUCKET_NAME = "test_bucket"
AWS_ACCESS_KEY_ID = "testing_key_id"
AWS_SECRET_ACCESS_KEY = "testing_access_key"
AWS_STORAGE_BUCKET_NAME = "test-bucket"
65 changes: 29 additions & 36 deletions tests/test_s3boto3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
import io
import pickle
import threading
from datetime import datetime
from textwrap import dedent
from unittest import mock
from unittest import skipIf
from urllib.parse import urlparse

import boto3
import boto3.s3.transfer
from botocore.exceptions import ClientError
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.files.base import ContentFile
from django.core.files.base import File
from django.test import TestCase
from django.test import override_settings
from django.utils.timezone import is_aware
Expand All @@ -34,11 +35,11 @@ def setUp(self):
self.storage._connections.connection = mock.MagicMock()

def test_s3_session(self):
settings.AWS_S3_SESSION_PROFILE = "test_profile"
with mock.patch('boto3.Session') as mock_session:
storage = s3boto3.S3Boto3Storage()
_ = storage.connection
mock_session.assert_called_once_with(profile_name="test_profile")
with override_settings(AWS_S3_SESSION_PROFILE="test_profile"):
with mock.patch('boto3.Session') as mock_session:
storage = s3boto3.S3Boto3Storage()
_ = storage.connection
mock_session.assert_called_once_with(profile_name="test_profile")

def test_pickle_with_bucket(self):
"""
Expand Down Expand Up @@ -96,7 +97,7 @@ def test_storage_save(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'text/plain',
},
Expand All @@ -114,7 +115,7 @@ def test_storage_save_non_seekable(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'text/plain',
},
Expand Down Expand Up @@ -174,7 +175,7 @@ def test_content_type(self):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
'ContentType': 'image/jpeg',
},
Expand All @@ -189,7 +190,7 @@ def test_storage_save_gzipped(self):
content = ContentFile("I am gzip'd")
self.storage.save(name, content)
obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
obj.upload_fileobj.assert_called_once_with(
mock.ANY,
ExtraArgs={
'ContentType': 'application/octet-stream',
Expand All @@ -210,7 +211,7 @@ def get_object_parameters(name):

obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
mock.ANY,
ExtraArgs={
"ContentType": "application/gzip",
},
Expand All @@ -225,8 +226,8 @@ def test_storage_save_gzipped_non_seekable(self):
content = NonSeekableContentFile("I am gzip'd")
self.storage.save(name, content)
obj = self.storage.bucket.Object.return_value
obj.upload_fileobj.assert_called_with(
content,
obj.upload_fileobj.assert_called_once_with(
mock.ANY,
ExtraArgs={
'ContentType': 'application/octet-stream',
'ContentEncoding': 'gzip',
Expand Down Expand Up @@ -617,7 +618,7 @@ def test_storage_listdir_empty(self):
self.storage._connections.connection.meta.client.get_paginator.return_value = paginator

dirs, files = self.storage.listdir('dir/')
paginator.paginate.assert_called_with(Bucket=None, Delimiter='/', Prefix='dir/')
paginator.paginate.assert_called_with(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Delimiter='/', Prefix='dir/')

self.assertEqual(dirs, [])
self.assertEqual(files, [])
Expand Down Expand Up @@ -868,55 +869,47 @@ def test_closed(self):
f.close()
self.assertTrue(f.closed)


@mock_s3
class S3Boto3StorageTestsWithMoto(TestCase):
"""
These tests use the moto library to mock S3, rather than unittest.mock.
This is better because more of boto3's internal code will be run in tests.
For example this issue
https://github.com/jschneier/django-storages/issues/708
wouldn't be caught using unittest.mock, since the error occurs in boto3's internals.
Using mock_s3 as a class decorator automatically decorates methods,
but NOT classmethods or staticmethods.
"""
@classmethod
@mock_s3
def setUpClass(cls):
super().setUpClass()
# create a bucket specified in settings.
cls.bucket = boto3.resource("s3").Bucket(settings.AWS_STORAGE_BUCKET_NAME)

def setUp(cls):
super().setUp()

cls.storage = s3boto3.S3Boto3Storage()
cls.bucket = cls.storage.connection.Bucket(settings.AWS_STORAGE_BUCKET_NAME)
cls.bucket.create()
# create a S3Boto3Storage backend instance.
cls.s3boto3_storage = s3boto3.S3Boto3Storage()

def test_save_bytes_file(self):
self.s3boto3_storage.save("bytes_file.txt", File(io.BytesIO(b"foo1")))
self.storage.save("bytes_file.txt", File(io.BytesIO(b"foo1")))

self.assertEqual(
b"foo1",
self.bucket.Object("bytes_file.txt").get()['Body'].read(),
)

def test_save_string_file(self):
self.s3boto3_storage.save("string_file.txt", File(io.StringIO("foo2")))
self.storage.save("string_file.txt", File(io.StringIO("foo2")))

self.assertEqual(
b"foo2",
self.bucket.Object("string_file.txt").get()['Body'].read(),
)

def test_save_bytes_content_file(self):
self.s3boto3_storage.save("bytes_content.txt", ContentFile(b"foo3"))
self.storage.save("bytes_content.txt", ContentFile(b"foo3"))

self.assertEqual(
b"foo3",
self.bucket.Object("bytes_content.txt").get()['Body'].read(),
)

def test_save_string_content_file(self):
self.s3boto3_storage.save("string_content.txt", ContentFile("foo4"))
self.storage.save("string_content.txt", ContentFile("foo4"))

self.assertEqual(
b"foo4",
Expand All @@ -930,7 +923,7 @@ def test_content_type_guess(self):
name = 'test_image.jpg'
content = ContentFile(b'data')
content.content_type = None
self.s3boto3_storage.save(name, content)
self.storage.save(name, content)

s3_object_fetched = self.bucket.Object(name).get()
self.assertEqual(b"data", s3_object_fetched['Body'].read())
Expand All @@ -942,7 +935,7 @@ def test_content_type_attribute(self):
"""
content = ContentFile(b'data')
content.content_type = "test/foo"
self.s3boto3_storage.save("test_file", content)
self.storage.save("test_file", content)

s3_object_fetched = self.bucket.Object("test_file").get()
self.assertEqual(b"data", s3_object_fetched['Body'].read())
Expand All @@ -954,7 +947,7 @@ def test_content_type_not_detectable(self):
"""
content = ContentFile(b'data')
content.content_type = None
self.s3boto3_storage.save("test_file", content)
self.storage.save("test_file", content)

s3_object_fetched = self.bucket.Object("test_file").get()
self.assertEqual(b"data", s3_object_fetched['Body'].read())
Expand Down

0 comments on commit 1f45042

Please sign in to comment.