Skip to content

Commit

Permalink
fix: Fix security vuln with file upload (#2067)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 authored Nov 20, 2024
1 parent 794425a commit 746efc4
Show file tree
Hide file tree
Showing 11 changed files with 148 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
- "test_o1_agent.py"
- "test_tool_rule_solver.py"
- "test_agent_tool_graph.py"
- "test_utils.py"
services:
qdrant:
image: qdrant/qdrant
Expand Down Expand Up @@ -131,4 +132,4 @@ jobs:
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
poetry run pytest -s -vv -k "not test_utils.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
7 changes: 4 additions & 3 deletions letta/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,10 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
else:
raise ValueError(f"Table type {table_type} not implemented")

for c in self.db_model.__table__.columns:
if c.name == "embedding":
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
if settings.pg_uri:
for c in self.db_model.__table__.columns:
if c.name == "embedding":
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"

from letta.server.server import db_context

Expand Down
3 changes: 3 additions & 0 deletions letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@
# TODO Is this config or constant?
CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000
CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000

# Upper bound, in characters, applied to generated filenames.
MAX_FILENAME_LENGTH = 255

# Device names that Windows reserves regardless of extension: the four legacy
# devices plus the serial (COM1-COM9) and parallel (LPT1-LPT9) ports.
RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL"} | {
    f"{device}{number}" for device in ("COM", "LPT") for number in range(1, 10)
}
2 changes: 1 addition & 1 deletion letta/schemas/letta_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,6 @@ def allow_bare_uuids(cls, v, values):
"""
_ = values # for SCA
if isinstance(v, UUID):
logger.warning("Bare UUIDs are deprecated, please use the full prefixed id!")
logger.warning(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
return f"{cls.__id_prefix__}-{v}"
return v
12 changes: 9 additions & 3 deletions letta/server/rest_api/routers/v1/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.utils import sanitize_filename

# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally

Expand Down Expand Up @@ -170,7 +171,7 @@ def upload_file_to_source(
server.ms.create_job(job)

# create background task
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes)
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, file=file, job_id=job.id, bytes=bytes)

# return job information
job = server.ms.get_job(job_id=job_id)
Expand Down Expand Up @@ -227,10 +228,15 @@ def delete_file_from_source(


def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
    """
    Write an uploaded file into a temporary directory and load it into a source.

    Parameters:
        server (SyncServer): Server whose ``load_file_to_source`` is invoked.
        source_id (str): Identifier of the source that receives the file.
        job_id (str): Identifier of the job passed through to the server.
        file (UploadFile): The upload; only its ``filename`` is read here.
        bytes (bytes): Raw file content written to disk.

    Raises:
        ValueError: Propagated from ``sanitize_filename`` when it rejects the
            uploaded filename.
    """
    # Create a temporary directory (deleted after the context manager exits)
    with tempfile.TemporaryDirectory() as tmpdirname:
        # Never join the client-supplied filename directly: only the sanitized
        # name is used to build the on-disk path.
        sanitized_filename = sanitize_filename(file.filename)
        file_path = os.path.join(tmpdirname, sanitized_filename)

        # Write the file to the sanitized path
        with open(file_path, "wb") as buffer:
            buffer.write(bytes)

        # Pass the file to load_file_to_source while the temporary file still exists
        server.load_file_to_source(source_id, file_path, job_id)
39 changes: 39 additions & 0 deletions letta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import demjson3 as demjson
import pytz
import tiktoken
from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename

import letta
from letta.constants import (
Expand All @@ -29,6 +30,7 @@
CORE_MEMORY_PERSONA_CHAR_LIMIT,
FUNCTION_RETURN_CHAR_LIMIT,
LETTA_DIR,
MAX_FILENAME_LENGTH,
TOOL_CALL_ID_MAX_LEN,
)
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
Expand Down Expand Up @@ -1071,3 +1073,40 @@ def safe_serializer(obj):

def json_loads(data):
return json.loads(data, strict=False)


def sanitize_filename(filename: str) -> str:
    """
    Sanitize the given filename to prevent directory traversal, invalid characters,
    and reserved names while ensuring it fits within the maximum length allowed by the filesystem.

    Both the base name and the extension are sanitized, and both are truncated if
    needed so the result never exceeds ``MAX_FILENAME_LENGTH``.

    Parameters:
        filename (str): The user-provided filename.

    Returns:
        str: A sanitized filename of the form ``<base>_<32 hex chars><ext>`` that is
        unique and safe for use.

    Raises:
        ValueError: If the sanitized base name starts with a period.
    """
    # 32 hex characters for the UUID + 1 for the `_` separator
    suffix_length = 33

    # Extract the base filename to avoid directory components
    filename = os.path.basename(filename)

    # Split the base and extension
    base, ext = os.path.splitext(filename)

    # External sanitization library
    base = pathvalidate_sanitize_filename(base)

    # The extension is user-controlled too: sanitize the part after the leading
    # dot so null bytes or other invalid characters cannot reach the filesystem.
    if ext:
        ext = "." + pathvalidate_sanitize_filename(ext[1:])

    # Cannot start with a period
    if base.startswith("."):
        raise ValueError(f"Invalid filename - derived file name {base} cannot start with '.'")

    # Cap the extension first; otherwise an oversized extension would drive the
    # base-length budget negative and the result would exceed the limit.
    max_ext_length = MAX_FILENAME_LENGTH - suffix_length
    if len(ext) > max_ext_length:
        ext = ext[:max_ext_length]

    # Truncate the base name to fit within the maximum allowed length
    max_base_length = MAX_FILENAME_LENGTH - len(ext) - suffix_length
    if len(base) > max_base_length:
        base = base[:max_base_length]

    # Append a unique UUID suffix for uniqueness
    unique_suffix = uuid.uuid4().hex
    sanitized_filename = f"{base}_{unique_suffix}{ext}"

    # Return the sanitized filename
    return sanitized_filename
18 changes: 17 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ alembic = "^1.13.3"
pyhumps = "^3.8.0"
psycopg2 = "^2.9.10"
psycopg2-binary = "^2.9.10"
pathvalidate = "^3.2.1"

[tool.poetry.extras]
#local = ["llama-index-embeddings-huggingface"]
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers/client_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Sou
assert active_jobs[0].metadata_["source_id"] == source.id

# wait for job to finish (with timeout)
timeout = 120
timeout = 240
start_time = time.time()
while True:
status = client.get_job(upload_job.id).status
Expand Down
6 changes: 5 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import threading
import time
import uuid
Expand Down Expand Up @@ -435,7 +436,10 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):

# Get the memgpt paper
file = files[0]
assert file.file_name == "memgpt_paper.pdf"
# Assert the filename matches the pattern
pattern = re.compile(r"^memgpt_paper_[a-f0-9]{32}\.pdf$")
assert pattern.match(file.file_name), f"Filename '{file.file_name}' does not match expected pattern."

assert file.source_id == source.id


Expand Down
66 changes: 66 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest

from letta.constants import MAX_FILENAME_LENGTH
from letta.utils import sanitize_filename


def test_valid_filename():
    """An ordinary filename keeps its stem as prefix and its extension as suffix."""
    result = sanitize_filename("valid_filename.txt")
    assert result.startswith("valid_filename_")
    assert result.endswith(".txt")


def test_filename_with_special_characters():
    """Only the valid characters after the last path separator are kept in the stem."""
    raw_name = "invalid:/<>?*ƒfilename.txt"
    result = sanitize_filename(raw_name)
    assert result.startswith("ƒfilename_")
    assert result.endswith(".txt")


def test_null_byte_in_filename():
    """A null byte inside the stem is removed, joining the surrounding text."""
    raw_name = "valid\0filename.txt"
    result = sanitize_filename(raw_name)
    assert "\0" not in result
    assert result.startswith("validfilename_")
    assert result.endswith(".txt")


def test_path_traversal_characters():
    """Directory components of a traversal attempt are dropped; only the final name remains."""
    result = sanitize_filename("../../etc/passwd")
    assert result.startswith("passwd_")
    assert len(result) <= MAX_FILENAME_LENGTH


def test_empty_filename():
    """An empty filename yields a name that begins with the `_` separator."""
    result = sanitize_filename("")
    assert result.startswith("_")


def test_dot_as_filename():
    """A bare `.` is rejected with a ValueError."""
    raw_name = "."
    with pytest.raises(ValueError, match="Invalid filename"):
        sanitize_filename(raw_name)


def test_dotdot_as_filename():
    """A bare `..` is rejected with a ValueError."""
    raw_name = ".."
    with pytest.raises(ValueError, match="Invalid filename"):
        sanitize_filename(raw_name)


def test_long_filename():
    """An over-long stem is shortened to fit the limit while the extension is kept."""
    oversized_stem = "a" * (MAX_FILENAME_LENGTH + 10)
    result = sanitize_filename(oversized_stem + ".txt")
    assert len(result) <= MAX_FILENAME_LENGTH
    assert result.endswith(".txt")


def test_unique_filenames():
    """Sanitizing the same name twice gives two distinct results with the same stem and extension."""
    raw_name = "duplicate.txt"
    first = sanitize_filename(raw_name)
    second = sanitize_filename(raw_name)
    assert first != second
    for candidate in (first, second):
        assert candidate.startswith("duplicate_")
    for candidate in (first, second):
        assert candidate.endswith(".txt")

0 comments on commit 746efc4

Please sign in to comment.