Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[general] Write overwriting files in terms of .exists() #1422

Merged
merged 1 commit into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions storages/backends/azure_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from storages.base import BaseStorage
from storages.utils import clean_name
from storages.utils import get_available_overwrite_name
from storages.utils import safe_join
from storages.utils import setting
from storages.utils import to_bytes
Expand Down Expand Up @@ -232,20 +231,13 @@ def _get_valid_path(self, name):
def _open(self, name, mode="rb"):
    """Build an ``AzureStorageFile`` for ``name`` in ``mode``, passing this storage along."""
    azure_file = AzureStorageFile(name, mode, self)
    return azure_file

def get_available_name(self, name, max_length=_AZURE_NAME_MAX_LEN):
    """
    Choose the name under which new content will be written.

    The incoming name is cleaned first. With ``overwrite_files`` enabled the
    cleaned name is passed to ``get_available_overwrite_name``; otherwise the
    parent class picks the name.
    """
    cleaned = clean_name(name)
    if not self.overwrite_files:
        return super().get_available_name(cleaned, max_length)
    return get_available_overwrite_name(cleaned, max_length)

def exists(self, name):
    """
    Report whether ``name`` should be treated as already present.

    A falsy name is always reported as existing. When ``overwrite_files`` is
    set, every other name is reported as absent without querying the blob
    client; otherwise the blob client for the validated path is asked.
    """
    if not name:
        return True
    if not self.overwrite_files:
        # Resolve the client accessor before validating the path, keeping
        # the evaluation order of the single-expression form.
        fetch_client = self.client.get_blob_client
        return fetch_client(self._get_valid_path(name)).exists()
    return False

Expand Down
11 changes: 3 additions & 8 deletions storages/backends/dropbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from dropbox.files import WriteMode

from storages.base import BaseStorage
from storages.utils import get_available_overwrite_name
from storages.utils import setting

# Fallback timeout value, judging by the name; where it is used and its unit
# are not visible here (presumably seconds) -- TODO confirm.
_DEFAULT_TIMEOUT = 100
Expand Down Expand Up @@ -130,6 +129,9 @@ def delete(self, name):
self.client.files_delete(self._full_path(name))

def exists(self, name):
if self.write_mode == "overwrite":
return False

try:
return bool(self.client.files_get_metadata(self._full_path(name)))
except ApiError:
Expand Down Expand Up @@ -199,12 +201,5 @@ def _chunked_upload(self, content, dest_path):
)
cursor.offset = content.tell()

def get_available_name(self, name, max_length=None):
    """
    Choose the name to store under, based on the full path of ``name``.

    In ``"overwrite"`` write mode the full path is passed to
    ``get_available_overwrite_name``; in any other mode the parent class
    chooses the name.
    """
    full_name = self._full_path(name)
    if self.write_mode != "overwrite":
        return super().get_available_name(full_name, max_length)
    return get_available_overwrite_name(full_name, max_length)


# Second module-level name bound to the same class (note the capital "B");
# presumably kept so existing imports keep working -- TODO confirm.
DropBoxStorage = DropboxStorage
10 changes: 3 additions & 7 deletions storages/backends/gcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from storages.compress import CompressedFileMixin
from storages.utils import check_location
from storages.utils import clean_name
from storages.utils import get_available_overwrite_name
from storages.utils import safe_join
from storages.utils import setting
from storages.utils import to_bytes
Expand Down Expand Up @@ -243,6 +242,9 @@ def exists(self, name):
except NotFound:
return False

if self.file_overwrite:
return False

name = self._normalize_name(clean_name(name))
return bool(self.bucket.get_blob(name))

Expand Down Expand Up @@ -333,9 +335,3 @@ def url(self, name, parameters=None):
params[key] = value

return blob.generate_signed_url(**params)

def get_available_name(self, name, max_length=None):
    """
    Choose the name to store under, starting from the cleaned ``name``.

    With ``file_overwrite`` enabled the cleaned name is passed to
    ``get_available_overwrite_name``; otherwise the parent class decides.
    """
    cleaned = clean_name(name)
    if not self.file_overwrite:
        return super().get_available_name(cleaned, max_length)
    return get_available_overwrite_name(cleaned, max_length)
11 changes: 3 additions & 8 deletions storages/backends/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from storages.utils import ReadBytesWrapper
from storages.utils import check_location
from storages.utils import clean_name
from storages.utils import get_available_overwrite_name
from storages.utils import is_seekable
from storages.utils import lookup_env
from storages.utils import safe_join
Expand Down Expand Up @@ -580,6 +579,9 @@ def delete(self, name):
raise

def exists(self, name):
if self.file_overwrite:
return False

name = self._normalize_name(clean_name(name))
try:
self.connection.meta.client.head_object(Bucket=self.bucket_name, Key=name)
Expand Down Expand Up @@ -696,13 +698,6 @@ def url(self, name, parameters=None, expire=None, http_method=None):
)
return url

def get_available_name(self, name, max_length=None):
    """
    Choose the name to store under, starting from the cleaned ``name``.

    When ``file_overwrite`` is set, the cleaned name is passed to
    ``get_available_overwrite_name`` so an existing file of that name is
    overwritten; otherwise the parent class chooses the name.
    """
    candidate = clean_name(name)
    if not self.file_overwrite:
        return super().get_available_name(candidate, max_length)
    return get_available_overwrite_name(candidate, max_length)


class S3StaticStorage(S3Storage):
"""Querystring auth must be disabled so that url() returns a consistent output."""
Expand Down
20 changes: 0 additions & 20 deletions storages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import SuspiciousFileOperation
from django.core.files.utils import FileProxyMixin
from django.utils.encoding import force_bytes

Expand Down Expand Up @@ -120,25 +119,6 @@ def lookup_env(names):
return value


def get_available_overwrite_name(name, max_length):
    """
    Return ``name`` shortened, if necessary, to fit within ``max_length``.

    The name is returned untouched when ``max_length`` is ``None`` or the
    name already fits. Otherwise the excess characters are cut from the end
    of the file root (the part before the extension), leaving the directory
    and extension as they are.

    Raises ``SuspiciousFileOperation`` when the cut would leave no file root.
    """
    if max_length is not None and len(name) > max_length:
        # Adapted from Django
        folder, base = os.path.split(name)
        stem, suffix = os.path.splitext(base)
        excess = len(name) - max_length

        shortened = stem[:-excess]
        if shortened:
            return os.path.join(folder, "{}{}".format(shortened, suffix))
        raise SuspiciousFileOperation(
            'Storage tried to truncate away entire filename "%s". '
            "Please make sure that the corresponding file field "
            'allows sufficient "max_length".' % name
        )
    return name


def is_seekable(file_object):
    """
    Tell whether ``file_object`` can be treated as seekable.

    Objects without a ``seekable`` attribute are reported as seekable;
    otherwise the result of calling ``file_object.seekable()`` is returned.
    """
    if hasattr(file_object, "seekable"):
        return file_object.seekable()
    return True

Expand Down
75 changes: 9 additions & 66 deletions tests/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from datetime import timedelta
from unittest import mock

import django
from azure.storage.blob import BlobProperties
from django.core.exceptions import SuspiciousOperation
from django.core.files.base import ContentFile
from django.test import TestCase
from django.test import override_settings
Expand Down Expand Up @@ -70,65 +68,6 @@ def test_get_valid_path_idempotency(self):
self.storage._get_valid_path(some_path),
)

def test_get_available_name(self):
    # With overwriting off and the first existence probe answering True,
    # the returned name must differ from "foo.txt" but keep the "foo_"
    # prefix and the ".txt" extension.
    self.storage.overwrite_files = False
    blob_client = mock.MagicMock()
    blob_client.exists.side_effect = [True, False]
    self.storage._client.get_blob_client.return_value = blob_client
    available = self.storage.get_available_name("foo.txt")
    self.assertTrue(available.startswith("foo_"))
    self.assertTrue(available.endswith(".txt"))
    self.assertTrue(len(available) > len("foo.txt"))
    self.assertEqual(blob_client.exists.call_count, 2)

def test_get_available_name_first(self):
    # When the first existence probe answers False, the requested name is
    # returned unchanged after exactly one probe.
    self.storage.overwrite_files = False
    blob_client = mock.MagicMock()
    blob_client.exists.return_value = False
    self.storage._client.get_blob_client.return_value = blob_client
    available = self.storage.get_available_name("foo bar baz.txt")
    self.assertEqual(available, "foo bar baz.txt")
    self.assertEqual(blob_client.exists.call_count, 1)

def test_get_available_name_max_len(self):
    self.storage.overwrite_files = False
    # if you wonder why this is, file-system
    # storage will raise when file name is too long as well,
    # the form should validate this
    blob_client = mock.MagicMock()
    blob_client.exists.side_effect = [True, False]
    self.storage._client.get_blob_client.return_value = blob_client
    with self.assertRaises(ValueError):
        self.storage.get_available_name("a" * 1025)
    # max_len == 1024
    shortened = self.storage.get_available_name("a" * 1000, max_length=100)
    self.assertEqual(len(shortened), 100)
    self.assertTrue("_" in shortened)
    self.assertEqual(blob_client.exists.call_count, 2)

def test_get_available_invalid(self):
    self.storage.overwrite_files = False
    self.storage._client.exists.return_value = False
    # Django 2.2.21 added this security fix:
    # https://docs.djangoproject.com/en/3.2/releases/2.2.21/#cve-2021-31542-potential-directory-traversal-via-uploaded-files
    # It raises SuspiciousOperation before we get to our ValueError.
    # The fix wasn't applied to 3.0 (no longer in support), but was applied to
    # 3.1 & 3.2, so only 3.0 still surfaces our own ValueError.
    if django.VERSION[:2] == (3, 0):
        expected_error = ValueError
    else:
        expected_error = SuspiciousOperation
    for bad_name in ("", "/", ".", "///"):
        self.assertRaises(
            expected_error, self.storage.get_available_name, bad_name
        )
    self.assertRaises(ValueError, self.storage.get_available_name, "...")

def test_url(self):
blob_mock = mock.MagicMock()
blob_mock.url = "https://ret_foo.blob.core.windows.net/test/some%20blob"
Expand Down Expand Up @@ -357,11 +296,15 @@ def test_storage_open_write(self):
)

def test_storage_exists(self):
    # NOTE(review): the statements before the loop look like the pre-change
    # version of this test shown alongside the new one -- confirm against
    # the real file. All statements are kept as they appear.
    blob_name = "blob"
    client_mock = mock.MagicMock()
    self.storage._client.get_blob_client.return_value = client_mock
    self.assertTrue(self.storage.exists(blob_name))
    self.assertEqual(client_mock.exists.call_count, 1)
    # With overwriting on, exists() must answer False without probing the
    # blob client; with it off, the client is probed exactly once.
    for overwrite in (True, False):
        self.storage.overwrite_files = overwrite
        fresh_client = mock.MagicMock()
        self.storage._client.get_blob_client.return_value = fresh_client
        if overwrite:
            expected_calls = 0
            self.assertFalse(self.storage.exists("blob"))
        else:
            expected_calls = 1
            self.assertTrue(self.storage.exists("blob"))
        self.assertEqual(fresh_client.exists.call_count, expected_calls)

def test_delete_blob(self):
self.storage.delete("name")
Expand Down
5 changes: 5 additions & 0 deletions tests/test_dropbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def test_not_exists(self, *args):
exists = self.storage.exists("bar")
self.assertFalse(exists)

@mock.patch("dropbox.Dropbox.files_get_metadata", return_value=[FILE_METADATA_MOCK])
def test_exists_overwrite_mode(self, *args):
    # Even though the patched metadata lookup would return an entry,
    # exists() must answer False in "overwrite" write mode.
    self.storage.write_mode = "overwrite"
    exists = self.storage.exists("foo")
    self.assertFalse(exists)

@mock.patch("dropbox.Dropbox.files_list_folder", return_value=FILES_MOCK)
def test_listdir(self, *args):
dirs, files = self.storage.listdir("/")
Expand Down
5 changes: 5 additions & 0 deletions tests/test_gcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def test_delete(self):
)

def test_exists(self):
self.storage.file_overwrite = False
self.storage._bucket = mock.MagicMock()
self.assertTrue(self.storage.exists(self.filename))
self.storage._bucket.get_blob.assert_called_with(self.filename)
Expand All @@ -186,6 +187,10 @@ def test_exists_bucket(self):
# exists('') should return True if the bucket exists
self.assertTrue(self.storage.exists(""))

def test_exists_file_overwrite(self):
    # With file_overwrite enabled, exists() reports the file as absent.
    self.storage.file_overwrite = True
    found = self.storage.exists(self.filename)
    self.assertFalse(found)

def test_listdir(self):
file_names = ["some/path/1.txt", "2.txt", "other/path/3.txt", "4.txt"]
subdir = ""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,15 @@ def test_storage_write_beyond_buffer_size(self):
)

def test_storage_exists(self):
    # With overwriting off, exists() must issue a head_object call for the
    # key and report True.
    self.storage.file_overwrite = False
    found = self.storage.exists("file.txt")
    self.assertTrue(found)
    head_object = self.storage.connection.meta.client.head_object
    head_object.assert_called_with(
        Bucket=self.storage.bucket_name,
        Key="file.txt",
    )

def test_storage_exists_false(self):
self.storage.file_overwrite = False
self.storage.connection.meta.client.head_object.side_effect = ClientError(
{"Error": {}, "ResponseMetadata": {"HTTPStatusCode": 404}},
"HeadObject",
Expand All @@ -514,6 +516,7 @@ def test_storage_exists_false(self):
)

def test_storage_exists_other_error_reraise(self):
self.storage.file_overwrite = False
self.storage.connection.meta.client.head_object.side_effect = ClientError(
{"Error": {}, "ResponseMetadata": {"HTTPStatusCode": 403}},
"HeadObject",
Expand All @@ -525,6 +528,10 @@ def test_storage_exists_other_error_reraise(self):
cm.exception.response["ResponseMetadata"]["HTTPStatusCode"], 403
)

def test_storage_exists_overwrite(self):
    # With file_overwrite enabled, exists() reports the file as absent.
    self.storage.file_overwrite = True
    found = self.storage.exists("foo")
    self.assertFalse(found)

def test_storage_delete(self):
self.storage.delete("path/to/file.txt")
self.storage.bucket.Object.assert_called_with("path/to/file.txt")
Expand Down
25 changes: 0 additions & 25 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
import pathlib

from django.conf import settings
from django.core.exceptions import SuspiciousFileOperation
from django.test import TestCase

from storages import utils
from storages.utils import get_available_overwrite_name as gaon


class SettingTest(TestCase):
Expand Down Expand Up @@ -118,29 +116,6 @@ def test_with_base_url_join_nothing(self):
self.assertEqual(path, "base_url/")


class TestGetAvailableOverwriteName(TestCase):
    # Exercises get_available_overwrite_name (imported as ``gaon``) against
    # max_length values around the length of the name.

    def test_maxlength_is_none(self):
        path = "superlong/file/with/path.txt"
        result = gaon(path, None)
        self.assertEqual(result, path)

    def test_maxlength_equals_name(self):
        path = "parent/child.txt"
        limit = len(path)
        self.assertEqual(gaon(path, limit), path)

    def test_maxlength_is_greater_than_name(self):
        path = "parent/child.txt"
        limit = len(path) + 1
        self.assertEqual(gaon(path, limit), path)

    def test_maxlength_less_than_name(self):
        # One character over the limit: a single character is cut from the
        # file root, not from the extension.
        path = "parent/child.txt"
        limit = len(path) - 1
        self.assertEqual(gaon(path, limit), "parent/chil.txt")

    def test_truncates_away_filename_raises(self):
        # The root "child" has five characters, so cutting five leaves none.
        path = "parent/child.txt"
        limit = len(path) - 5
        self.assertRaises(SuspiciousFileOperation, gaon, path, limit)


class TestReadBytesWrapper(TestCase):
def test_with_bytes_file(self):
file = io.BytesIO(b"abcd")
Expand Down
Loading