From f66c492f1e2e222d1541ee1bccde254be3d6af48 Mon Sep 17 00:00:00 2001 From: Chris Clark Date: Sat, 27 Jul 2024 12:53:41 -0400 Subject: [PATCH] better secure filenames --- explorer/ee/db_connections/create_sqlite.py | 14 +++++-- explorer/tests/test_create_sqlite.py | 41 ++++++++++++++++++++- explorer/tests/test_utils.py | 34 ++++++++++++++++++++- explorer/utils.py | 12 +++++- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/explorer/ee/db_connections/create_sqlite.py b/explorer/ee/db_connections/create_sqlite.py index 4cd64694..e65e55f8 100644 --- a/explorer/ee/db_connections/create_sqlite.py +++ b/explorer/ee/db_connections/create_sqlite.py @@ -6,10 +6,9 @@ from explorer.ee.db_connections.utils import pandas_to_sqlite, uploaded_db_local_path -def parse_to_sqlite(file, append_conn=None, user_id=None) -> (BytesIO, str): - - table_name, _ = os.path.splitext(file.name) - table_name = secure_filename(table_name) +def get_names(file, append_conn=None, user_id=None): + s_filename = secure_filename(file.name) + table_name, _ = os.path.splitext(s_filename) # f_name represents the filename of both the sqlite DB on S3, and on the local filesystem. # If we are appending to an existing data source, then we re-use the same name. 
@@ -19,6 +18,13 @@ def parse_to_sqlite(file, append_conn=None, user_id=None) -> (BytesIO, str): else: f_name = f"{table_name}_{user_id}.db" + return table_name, f_name + + +def parse_to_sqlite(file, append_conn=None, user_id=None) -> (BytesIO, str): + + table_name, f_name = get_names(file, append_conn, user_id) + # When appending, make sure the database exists locally so that we can write to it if append_conn: append_conn.download_sqlite_if_needed() diff --git a/explorer/tests/test_create_sqlite.py b/explorer/tests/test_create_sqlite.py index 3d40c614..4efa9eb7 100644 --- a/explorer/tests/test_create_sqlite.py +++ b/explorer/tests/test_create_sqlite.py @@ -1,8 +1,8 @@ from django.test import TestCase from django.core.files.uploadedfile import SimpleUploadedFile -from unittest import skipIf +from unittest import skipIf, mock from explorer.app_settings import EXPLORER_USER_UPLOADS_ENABLED -from explorer.ee.db_connections.create_sqlite import parse_to_sqlite +from explorer.ee.db_connections.create_sqlite import parse_to_sqlite, get_names import os import sqlite3 @@ -45,3 +45,40 @@ def test_parse_to_sqlite_with_no_parser(self): self.assertEqual(rows[0], ("chris", "cto")) self.assertEqual(name, "name_1.db") + + +class TestGetNames(TestCase): + def setUp(self): + # Mock file object + self.mock_file = mock.MagicMock() + self.mock_file.name = "test file name.txt" + + # Mock append_conn object + self.mock_append_conn = mock.MagicMock() + self.mock_append_conn.name = "/path/to/existing_db.sqlite" + + def test_no_append_conn(self): + table_name, f_name = get_names(self.mock_file, append_conn=None, user_id=123) + self.assertEqual(table_name, "test_file_name") + self.assertEqual(f_name, "test_file_name_123.db") + + def test_with_append_conn(self): + table_name, f_name = get_names(self.mock_file, append_conn=self.mock_append_conn, user_id=123) + self.assertEqual(table_name, "test_file_name") + self.assertEqual(f_name, "existing_db.sqlite") + + def test_secure_filename(self): + 
self.mock_file.name = "测试文件.txt" + table_name, f_name = get_names(self.mock_file, append_conn=None, user_id=123) + self.assertEqual(table_name, "_") + self.assertEqual(f_name, "__123.db") + + def test_empty_filename(self): + self.mock_file.name = ".txt" + with self.assertRaises(ValueError): + get_names(self.mock_file, append_conn=None, user_id=123) + + def test_invalid_extension(self): + self.mock_file.name = "filename.exe" + with self.assertRaises(ValueError): + get_names(self.mock_file, append_conn=None, user_id=123) diff --git a/explorer/tests/test_utils.py b/explorer/tests/test_utils.py index 1fa2ebfa..e160dc6b 100644 --- a/explorer/tests/test_utils.py +++ b/explorer/tests/test_utils.py @@ -6,7 +6,7 @@ from explorer.tests.factories import SimpleQueryFactory from explorer.utils import ( EXPLORER_PARAM_TOKEN, extract_params, get_params_for_url, get_params_from_request, param, passes_blacklist, - shared_dict_update, swap_params, + shared_dict_update, swap_params, secure_filename ) @@ -271,3 +271,35 @@ def test_only_registered_connections_are_in_connections(self): from explorer.connections import connections self.assertTrue(EXPLORER_DEFAULT_CONNECTION in connections()) self.assertNotEqual(len(connections()), len([c for c in djcs])) + + +class TestSecureFilename(TestCase): + def test_basic_ascii(self): + self.assertEqual(secure_filename("simple_file.txt"), "simple_file.txt") + + def test_special_characters(self): + self.assertEqual(secure_filename("file@name!.txt"), "file_name.txt") + + def test_leading_trailing_underscores(self): + self.assertEqual(secure_filename("_leading.txt"), "leading.txt") + self.assertEqual(secure_filename("trailing_.txt"), "trailing.txt") + self.assertEqual(secure_filename(".__filename__.txt"), "filename.txt") + + def test_unicode_characters(self): + self.assertEqual(secure_filename("fïléñâmé.txt"), "filename.txt") + self.assertEqual(secure_filename("测试文件.txt"), "_.txt") + + def test_empty_filename(self): + with 
self.assertRaises(ValueError): + secure_filename("") + + def test_bad_extension(self): + with self.assertRaises(ValueError): + secure_filename("foo.xyz") + + def test_empty_extension(self): + with self.assertRaises(ValueError): + secure_filename("foo.") + + def test_spaces(self): + self.assertEqual(secure_filename("file name.txt"), "file_name.txt") diff --git a/explorer/utils.py b/explorer/utils.py index c127d03c..b6454310 100644 --- a/explorer/utils.py +++ b/explorer/utils.py @@ -1,4 +1,5 @@ import re +import os import unicodedata from collections import deque from typing import Iterable, Tuple @@ -254,6 +255,15 @@ def is_xls_writer_available(): def secure_filename(filename): + filename, ext = os.path.splitext(filename) + if not filename and not ext: + raise ValueError("Filename or extension cannot be blank") + if ext.lower() not in [".db", ".sqlite", ".sqlite3", ".csv", ".json", ".txt"]: + raise ValueError(f"Invalid extension: {ext}") + filename = unicodedata.normalize("NFKD", filename).encode("ascii", "ignore").decode("ascii") filename = re.sub(r"[^a-zA-Z0-9_.-]", "_", filename) - return filename.strip("._") + filename = filename.strip("._") + if not filename: # If filename becomes empty, replace it with an underscore + filename = "_" + return f"{filename}{ext}"