From 746efc4c0639db4aad1080713fb347ba35674a3e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 19 Nov 2024 17:07:04 -0800 Subject: [PATCH] fix: Fix security vuln with file upload (#2067) --- .github/workflows/tests.yml | 3 +- letta/agent_store/db.py | 7 ++- letta/constants.py | 3 + letta/schemas/letta_base.py | 2 +- letta/server/rest_api/routers/v1/sources.py | 12 +++- letta/utils.py | 39 ++++++++++++ poetry.lock | 18 +++++- pyproject.toml | 1 + tests/helpers/client_helper.py | 2 +- tests/test_client.py | 6 +- tests/test_utils.py | 66 +++++++++++++++++++++ 11 files changed, 148 insertions(+), 11 deletions(-) create mode 100644 tests/test_utils.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index adda0fd42a..1d658c660e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,6 +29,7 @@ jobs: - "test_o1_agent.py" - "test_tool_rule_solver.py" - "test_agent_tool_graph.py" + - "test_utils.py" services: qdrant: image: qdrant/qdrant @@ -131,4 +132,4 @@ jobs: LETTA_SERVER_PASS: test_server_token PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }} run: | - poetry run pytest -s -vv -k "not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests + poetry run pytest -s -vv -k "not test_utils.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index b9f3b40e4e..b9726e6a8c 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -380,9 +380,10 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None) else: raise ValueError(f"Table type {table_type} not implemented") - for c in self.db_model.__table__.columns: - if c.name == "embedding": - assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}" + if settings.pg_uri: + for c in self.db_model.__table__.columns: + if c.name == "embedding": + assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}" from letta.server.server import db_context diff --git a/letta/constants.py b/letta/constants.py index 0539477a46..436b9dcc4a 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -158,3 +158,6 @@ # TODO Is this config or constant? CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000 CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000 + +MAX_FILENAME_LENGTH = 255 +RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"} diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index 3855a3ab6a..9f1af6e99a 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -77,6 +77,6 @@ def allow_bare_uuids(cls, v, values): """ _ = values # for SCA if isinstance(v, UUID): - logger.warning("Bare UUIDs are deprecated, please use the full prefixed id!") + logger.warning(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!") return f"{cls.__id_prefix__}-{v}" return v diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index c0558b0ca5..48f05c565a 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -18,6 +18,7 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer +from letta.utils import sanitize_filename # These can be forward refs, but because Fastapi needs them at runtime the must be imported normally @@ -170,7 +171,7 @@ def upload_file_to_source( server.ms.create_job(job) # create background task - background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes) + background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, file=file, job_id=job.id, bytes=bytes) # return job information job = server.ms.get_job(job_id=job_id) @@ -227,10 +228,15 @@ def delete_file_from_source( def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes): - # write the file to a temporary directory (deleted after the context manager exits) + # Create a temporary directory (deleted after the context manager exits) with tempfile.TemporaryDirectory() as tmpdirname: - file_path = os.path.join(str(tmpdirname), str(file.filename)) + # Sanitize the filename + sanitized_filename = sanitize_filename(file.filename) + file_path = os.path.join(tmpdirname, sanitized_filename) + + # Write the file to the sanitized path with open(file_path, "wb") as buffer: buffer.write(bytes) + # Pass the file to load_file_to_source server.load_file_to_source(source_id, file_path, job_id) diff --git a/letta/utils.py b/letta/utils.py index c85f4ef818..a2f65111b9 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -21,6 +21,7 @@ import demjson3 as demjson import pytz import tiktoken +from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename import letta from letta.constants import ( @@ -29,6 +30,7 @@ CORE_MEMORY_PERSONA_CHAR_LIMIT, FUNCTION_RETURN_CHAR_LIMIT, LETTA_DIR, + MAX_FILENAME_LENGTH, TOOL_CALL_ID_MAX_LEN, ) from letta.schemas.openai.chat_completion_response import ChatCompletionResponse @@ -1071,3 +1073,40 @@ def safe_serializer(obj): def json_loads(data): return json.loads(data, strict=False) + + +def sanitize_filename(filename: str) -> str: + """ + Sanitize the given filename to prevent directory traversal, invalid characters, + and reserved names while ensuring it fits within the maximum length allowed by the filesystem. + + Parameters: + filename (str): The user-provided filename. + + Returns: + str: A sanitized filename that is unique and safe for use. + """ + # Extract the base filename to avoid directory components + filename = os.path.basename(filename) + + # Split the base and extension + base, ext = os.path.splitext(filename) + + # External sanitization library + base = pathvalidate_sanitize_filename(base) + + # Cannot start with a period + if base.startswith("."): + raise ValueError(f"Invalid filename - derived file name {base} cannot start with '.'") + + # Truncate the base name to fit within the maximum allowed length + max_base_length = MAX_FILENAME_LENGTH - len(ext) - 33 # 32 for UUID + 1 for `_` + if len(base) > max_base_length: + base = base[:max_base_length] + + # Append a unique UUID suffix for uniqueness + unique_suffix = uuid.uuid4().hex + sanitized_filename = f"{base}_{unique_suffix}{ext}" + + # Return the sanitized filename + return sanitized_filename diff --git a/poetry.lock b/poetry.lock index 33e5bd24f6..fef0037567 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4905,6 +4905,22 @@ files = [ {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, ] +[[package]] +name = "pathvalidate" +version = "3.2.1" +description = "pathvalidate is a Python library to sanitize/validate a string such as filenames/file-paths/etc." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathvalidate-3.2.1-py3-none-any.whl", hash = "sha256:9a6255eb8f63c9e2135b9be97a5ce08f10230128c4ae7b3e935378b82b22c4c9"}, + {file = "pathvalidate-3.2.1.tar.gz", hash = "sha256:f5d07b1e2374187040612a1fcd2bcb2919f8db180df254c9581bb90bf903377d"}, +] + +[package.extras] +docs = ["Sphinx (>=2.4)", "sphinx-rtd-theme (>=1.2.2)", "urllib3 (<2)"] +readme = ["path (>=13,<17)", "readmemaker (>=1.1.0)"] +test = ["Faker (>=1.0.8)", "allpairspy (>=2)", "click (>=6.2)", "pytest (>=6.0.1)", "pytest-md-report (>=0.6.2)"] + [[package]] name = "pexpect" version = "4.9.0" @@ -8494,4 +8510,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "570c482aed9ff66761ac47b8b7e1ca06525d4e5084791723380101217e163500" +content-hash = "5aef7fe9900da5d0fefbb0ce4f4f65b565f1967826f840138cfdd59444fd7330" diff --git a/pyproject.toml b/pyproject.toml index 8f9be5e14b..00ba87f157 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ alembic = "^1.13.3" pyhumps = "^3.8.0" psycopg2 = "^2.9.10" psycopg2-binary = "^2.9.10" +pathvalidate = "^3.2.1" [tool.poetry.extras] #local = ["llama-index-embeddings-huggingface"] diff --git a/tests/helpers/client_helper.py b/tests/helpers/client_helper.py index ac37b697b6..e7cce8ef74 100644 --- a/tests/helpers/client_helper.py +++ b/tests/helpers/client_helper.py @@ -20,7 +20,7 @@ def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Sou assert active_jobs[0].metadata_["source_id"] == source.id # wait for job to finish (with timeout) - timeout = 120 + timeout = 240 start_time = time.time() while True: status = client.get_job(upload_job.id).status diff --git a/tests/test_client.py b/tests/test_client.py index 7f5f095c82..56bbf9a632 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,5 @@ import os +import re import threading import time import uuid @@ -435,7 +436,10 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): # Get the memgpt paper file = files[0] - assert file.file_name == "memgpt_paper.pdf" + # Assert the filename matches the pattern + pattern = re.compile(r"^memgpt_paper_[a-f0-9]{32}\.pdf$") + assert pattern.match(file.file_name), f"Filename '{file.file_name}' does not match expected pattern." + assert file.source_id == source.id diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..904e903e74 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,66 @@ +import pytest + +from letta.constants import MAX_FILENAME_LENGTH +from letta.utils import sanitize_filename + + +def test_valid_filename(): + filename = "valid_filename.txt" + sanitized = sanitize_filename(filename) + assert sanitized.startswith("valid_filename_") + assert sanitized.endswith(".txt") + + +def test_filename_with_special_characters(): + filename = "invalid:/<>?*ƒfilename.txt" + sanitized = sanitize_filename(filename) + assert sanitized.startswith("ƒfilename_") + assert sanitized.endswith(".txt") + + +def test_null_byte_in_filename(): + filename = "valid\0filename.txt" + sanitized = sanitize_filename(filename) + assert "\0" not in sanitized + assert sanitized.startswith("validfilename_") + assert sanitized.endswith(".txt") + + +def test_path_traversal_characters(): + filename = "../../etc/passwd" + sanitized = sanitize_filename(filename) + assert sanitized.startswith("passwd_") + assert len(sanitized) <= MAX_FILENAME_LENGTH + + +def test_empty_filename(): + sanitized = sanitize_filename("") + assert sanitized.startswith("_") + + +def test_dot_as_filename(): + with pytest.raises(ValueError, match="Invalid filename"): + sanitize_filename(".") + + +def test_dotdot_as_filename(): + with pytest.raises(ValueError, match="Invalid filename"): + sanitize_filename("..") + + +def test_long_filename(): + filename = "a" * (MAX_FILENAME_LENGTH + 10) + ".txt" + sanitized = sanitize_filename(filename) + assert len(sanitized) <= MAX_FILENAME_LENGTH + assert sanitized.endswith(".txt") + + +def test_unique_filenames(): + filename = "duplicate.txt" + sanitized1 = sanitize_filename(filename) + sanitized2 = sanitize_filename(filename) + assert sanitized1 != sanitized2 + assert sanitized1.startswith("duplicate_") + assert sanitized2.startswith("duplicate_") + assert sanitized1.endswith(".txt") + assert sanitized2.endswith(".txt")