Skip to content

Commit

Permalink
v0.6.0 - Harden Security (#55)
Browse files Browse the repository at this point in the history
* added additional logging

* disable account creation and file uploads temporarily

* system now only allows user file upload if user is whitelisted and user has uploaded less than daily amount of bytes

* hard limit on total number of users allowed to be created as another form of security

* added USER_LIMIT to testing env file

* added invoke task show-users-table

* updated show-users-table invoke task

* fix: user_limit -> users_limit

* removed constant DAILY_UPLOAD_LIMIT_BYTES from operation_validator and instead retrieving value from get_settings()

* wip

* Polyfactory for Creating Mock Users (#59)

* wip: trying to isolate users from one another in test scenarios. todo: generate mock users using polyfactory

* upgraded packages

* fixed all tests except test_user_creation_limit

* added new factories.py file

* updated mypy ignore comment

* updated project version to 0.6.0

* updated requirements and added pyinvoke task for updating requirements.txt with the deps from uv's lockfile
  • Loading branch information
fullerzz authored Sep 7, 2024
1 parent 423c65f commit 108c4ee
Show file tree
Hide file tree
Showing 19 changed files with 814 additions and 372 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "smolvault"
version = "0.5.0"
version = "0.6.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
Expand All @@ -27,13 +27,15 @@ dev-dependencies = [
"boto3-stubs[essential]>=1.35.2",
"pre-commit>=3.8.0",
"pytest>=8.3.2",
"pytest-asyncio>=0.23.8",
"pytest-cov>=5.0.0",
"moto[all]>=5.0.13",
"invoke>=2.2.0",
"rich>=13.7.1",
"types-pyjwt>=1.7.1",
"httpx>=0.27.0",
"pytest-sugar>=1.0.0",
"anyio>=4.4.0",
"polyfactory>=2.16.2",
]

[tool.pytest.ini_options]
Expand Down
292 changes: 224 additions & 68 deletions requirements.txt

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion scripts/start_app.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#!/bin/bash

source .venv/bin/activate
hypercorn src.smolvault.main:app -b 0.0.0.0 --debug --log-config=logging.conf --log-level=DEBUG --access-logfile=hypercorn.access.log --error-logfile=hypercorn.error.log --keep-alive=120 --workers=2
hypercorn src.smolvault.main:app -b 0.0.0.0 --debug \
--log-config=logging.conf --log-level=DEBUG \
--access-logfile=hypercorn.access.log \
--error-logfile=hypercorn.error.log \
--keep-alive=120 --workers=2
5 changes: 4 additions & 1 deletion scripts/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ export SMOLVAULT_BUCKET="test-bucket"
export SMOLVAULT_DB="test.db"
export SMOLVAULT_CACHE="./uploads/"
export AUTH_SECRET_KEY="09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7" # key from FastAPI docs to use in tests
export DAILY_UPLOAD_LIMIT_BYTES="500000"
export USERS_LIMIT="20"
export USER_WHITELIST="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20"

# remove test db if it exists
if [ -f $SMOLVAULT_DB ]; then
Expand All @@ -18,4 +21,4 @@ fi
# create local cache dir
mkdir uploads

pytest -vvv tests/
pytest -vvv tests
2 changes: 1 addition & 1 deletion src/smolvault/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class NewUserDTO(BaseModel):
full_name: str
password: SecretStr

@computed_field # type: ignore
@computed_field # type: ignore[prop-decorator]
@cached_property
def hashed_password(self) -> str:
return bcrypt.hashpw(self.password.get_secret_value().encode(), bcrypt.gensalt()).decode()
Expand Down
15 changes: 14 additions & 1 deletion src/smolvault/clients/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from datetime import datetime

from sqlmodel import Field, Session, SQLModel, create_engine, select

Expand Down Expand Up @@ -59,9 +60,15 @@ def add_metadata(self, file_upload: FileUploadDTO, key: str) -> None:
session.add(FileTag(tag_name=tag, file_id=file_metadata.id))
session.commit()

def get_all_metadata(self, user_id: int) -> Sequence[FileMetadataRecord]:
def get_all_metadata(
self, user_id: int, start_time: datetime | None = None, end_time: datetime | None = None
) -> Sequence[FileMetadataRecord]:
with Session(self.engine) as session:
statement = select(FileMetadataRecord).where(FileMetadataRecord.user_id == user_id)
if start_time:
statement = statement.where(FileMetadataRecord.upload_timestamp >= start_time.isoformat())
if end_time:
statement = statement.where(FileMetadataRecord.upload_timestamp <= end_time.isoformat())
results = session.exec(statement)
return results.fetchall()

Expand Down Expand Up @@ -106,6 +113,12 @@ def get_user(self, username: str) -> UserInfo | None:
statement = select(UserInfo).where(UserInfo.username == username)
return session.exec(statement).first()

def get_user_count(self) -> int:
with Session(self.engine) as session:
statement = select(UserInfo)
results = session.exec(statement)
return len(results.fetchall())

def add_user(self, user: NewUserDTO) -> None:
user_info = UserInfo(
username=user.username,
Expand Down
3 changes: 3 additions & 0 deletions src/smolvault/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class Settings(BaseSettings):
smolvault_db: str
smolvault_cache: str
auth_secret_key: str
user_whitelist: str
users_limit: int
daily_upload_limit_bytes: int

model_config = SettingsConfigDict(env_file=".env")

Expand Down
59 changes: 42 additions & 17 deletions src/smolvault/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,20 @@
from logging.handlers import RotatingFileHandler
from typing import Annotated

from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
File,
Form,
HTTPException,
UploadFile,
)
from fastapi import BackgroundTasks, Depends, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.security import OAuth2PasswordRequestForm

from smolvault.auth.decoder import (
authenticate_user,
create_access_token,
get_current_user,
)
from smolvault.auth.decoder import authenticate_user, create_access_token, get_current_user
from smolvault.auth.models import NewUserDTO, Token, User
from smolvault.cache.cache_manager import CacheManager
from smolvault.clients.aws import S3Client
from smolvault.clients.database import DatabaseClient, FileMetadataRecord
from smolvault.config import Settings, get_settings
from smolvault.models import FileMetadata, FileTagsDTO, FileUploadDTO
from smolvault.validators.operation_validator import UploadValidator, UserCreationValidator

logging.basicConfig(
handlers=[
Expand Down Expand Up @@ -66,35 +55,58 @@ async def read_root(current_user: Annotated[User, Depends(get_current_user)]) ->

@app.post("/users/new")
async def create_user(
user: NewUserDTO, db_client: Annotated[DatabaseClient, Depends(DatabaseClient)]
user: NewUserDTO,
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
op_validator: Annotated[UserCreationValidator, Depends(UserCreationValidator)],
) -> dict[str, str]:
db_client.add_user(user)
return {"username": user.username}
logger.info("Received new user creation request for %s", user.username)
if op_validator.user_creation_allowed(db_client):
logger.info("Creating new user", extra=user.model_dump(exclude={"password"}))
db_client.add_user(user)
return {"username": user.username}
else:
logger.error("User creation failed. User limit exceeded")
raise HTTPException(
status_code=400,
detail="User limit exceeded",
)


@app.post("/token")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
) -> Token:
logger.info("Authenticating user %s", form_data.username)
user = authenticate_user(db_client, form_data.username, form_data.password)
if not user:
logger.info("Incorrect username or password for %s", form_data.username)
raise HTTPException(
status_code=400,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = create_access_token(data={"sub": user.username})
logger.info("User %s authenticated successfully", user.username)
return access_token


@app.post("/file/upload")
async def upload_file(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
op_validator: Annotated[UploadValidator, Depends(UploadValidator)],
file: Annotated[UploadFile, File()],
tags: str | None = Form(default=None),
) -> Response:
logger.info("Received file upload request from %s", current_user.username)
if not op_validator.upload_allowed(current_user.id, db_client):
logger.error("Upload limit exceeded for user %s", current_user.username)
return Response(
content=json.dumps({"error": "Upload limit exceeded"}),
status_code=400,
media_type="application/json",
)
contents = await file.read()
if file.filename is None:
logger.error("Filename not received in request")
Expand All @@ -113,6 +125,7 @@ async def upload_file(
)
object_key = s3_client.upload(data=file_upload)
db_client.add_metadata(file_upload, object_key)
logger.info("File %s uploaded successfully", file_upload.name)
return Response(
content=json.dumps(file_upload.model_dump(exclude={"content", "tags"})),
status_code=201,
Expand All @@ -127,6 +140,7 @@ async def get_file(
filename: str,
background_tasks: BackgroundTasks,
) -> Response:
logger.info("Received file download request for %s from %s", filename, current_user.username)
record = db_client.get_metadata(filename, current_user.id)
if record is None:
logger.info("File not found: %s", filename)
Expand All @@ -152,9 +166,12 @@ async def get_file_metadata(
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
name: str,
) -> FileMetadata | None:
logger.info("Retrieving metadata for file %s requested by %s", name, current_user.username)
record: FileMetadataRecord | None = db_client.get_metadata(urllib.parse.unquote(name), current_user.id)
if record:
logger.info("Retrieved metadata for file %s", name)
return FileMetadata.model_validate(record.model_dump())
logger.info("File metadata for %s not found", name)
return None


Expand All @@ -163,6 +180,7 @@ async def get_files(
current_user: Annotated[User, Depends(get_current_user)],
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
) -> list[FileMetadata]:
logger.info("Retrieving all files for user %s", current_user.username)
raw_metadata = db_client.get_all_metadata(current_user.id)
logger.info("Retrieved %d records from database", len(raw_metadata))
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
Expand All @@ -175,6 +193,7 @@ async def search_files(
db_client: Annotated[DatabaseClient, Depends(DatabaseClient)],
tag: str,
) -> list[FileMetadata]:
logger.info("Retrieving files with tag %s for user %s", tag, current_user.username)
raw_metadata = db_client.select_metadata_by_tag(tag, current_user.id)
logger.info("Retrieved %d records from database with tag %s", len(raw_metadata), tag)
results = [FileMetadata.model_validate(metadata.model_dump()) for metadata in raw_metadata]
Expand All @@ -188,8 +207,10 @@ async def update_file_tags(
name: str,
tags: FileTagsDTO,
) -> Response:
logger.info("Updating tags for file %s requested by %s", name, current_user.username)
record: FileMetadataRecord | None = db_client.get_metadata(name, current_user.id)
if record is None:
logger.info("Tag update failed. File %s not found", name)
return Response(
content=json.dumps({"error": "File not found"}),
status_code=404,
Expand All @@ -199,6 +220,7 @@ async def update_file_tags(
record.tags = tags.tags_str
db_client.update_metadata(record)
file_metadata = FileMetadata.model_validate(record.model_dump())
logger.info("Tags updated for file %s", name)
return Response(
content=json.dumps(
{
Expand All @@ -218,8 +240,10 @@ async def delete_file(
name: str,
background_tasks: BackgroundTasks,
) -> Response:
logger.info("Recieved delete request for file %s from %s", name, current_user.username)
record: FileMetadataRecord | None = db_client.get_metadata(name, current_user.id)
if record is None:
logger.info("File %s not found", name)
return Response(
content=json.dumps({"error": "File not found"}),
status_code=404,
Expand All @@ -229,6 +253,7 @@ async def delete_file(
db_client.delete_metadata(record, current_user.id)
if record.local_path:
background_tasks.add_task(cache.delete_file, record.local_path)
logger.info("File %s deleted successfully", name)
return Response(
content=json.dumps({"message": "File deleted successfully", "record": record.model_dump()}),
status_code=200,
Expand Down
47 changes: 47 additions & 0 deletions src/smolvault/validators/operation_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from datetime import datetime, timedelta

from smolvault.clients.database import DatabaseClient
from smolvault.config import get_settings

logger = logging.getLogger(__name__)


class UploadValidator:
def __init__(self) -> None:
self.settings = get_settings()
self.daily_upload_limit_bytes = self.settings.daily_upload_limit_bytes
self.whitelist = self.settings.user_whitelist.split(",")

def upload_allowed(self, user_id: int, db_client: DatabaseClient) -> bool:
valid = self._uploads_under_limit_prev_24h(user_id, db_client) and self._user_on_whitelist(user_id)
logger.info("Upload allowed result for user %s: %s", user_id, valid)
return valid

def _uploads_under_limit_prev_24h(self, user_id: int, db_client: DatabaseClient) -> bool:
logger.info("Checking upload limit for user %s", user_id)
start_time = datetime.now() - timedelta(days=1)
metadata = db_client.get_all_metadata(user_id, start_time=start_time)
bytes_uploaded = sum([record.size for record in metadata])
logger.info(
"User %s has uploaded %d bytes in the last 24 hours. DAILY_LIMIT: %d",
user_id,
bytes_uploaded,
self.daily_upload_limit_bytes,
)
return bytes_uploaded < self.daily_upload_limit_bytes

def _user_on_whitelist(self, user_id: int) -> bool:
logger.info("Checking whitelist for user %s", user_id)
return str(user_id) in self.whitelist


class UserCreationValidator:
def __init__(self) -> None:
self.settings = get_settings()
self.users_limit = self.settings.users_limit

def user_creation_allowed(self, db_client: DatabaseClient) -> bool:
users: int = db_client.get_user_count()
logger.info("%d users currently in the system", users)
return users < self.users_limit
33 changes: 33 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import sqlite3
from datetime import datetime
from typing import Any
from zoneinfo import ZoneInfo

from invoke.context import Context
from invoke.tasks import task
from rich import print
from rich.table import Table


@task
Expand Down Expand Up @@ -36,7 +38,38 @@ def show_table(c: Context) -> None:
conn.close()


def output_table(title: str, column_names: list[str], rows: list[Any]) -> None:
table = Table(title=title)
for column_name in column_names:
table.add_column(column_name)
for row in rows:
table.add_row(*row)
print(table)


@task
def show_users_table(c: Context) -> None:
conn = sqlite3.connect("file_metadata.db")
cursor = conn.cursor()
cursor.execute("SELECT * FROM userinfo")
results = cursor.fetchall()
conn.close()
rows: list[tuple[str, str, str, str]] = []
column_names = ["id", "username", "hashed_password", "email", "full_name"]
print(
f"[bold cyan]Unformatted results:[/bold cyan]\n[blue]column_names=[/blue][bold purple]{column_names}[/bold purple]\n {results}"
)
for result in results:
rows.append((str(result[0]), result[1], result[2], result[4])) # noqa: PERF401
output_table("[bold cyan]Users Table[/bold cyan]", ["id", "username", "hashed_pwd", "name"], rows)


@task
def bak_db(c: Context) -> None:
timestamp = datetime.now(ZoneInfo("UTC")).strftime("%Y-%m-%d_%H:%M:%S")
c.run(f"cp file_metadata.db file_metadata_{timestamp}.bak.db", echo=True)


@task
def export_reqs(c: Context) -> None:
c.run("uv export --no-emit-project --no-dev --output-file=requirements.txt", echo=True, pty=True)
Loading

0 comments on commit 108c4ee

Please sign in to comment.