diff --git a/.gitignore b/.gitignore index 5335d1a6..af9f85ce 100644 --- a/.gitignore +++ b/.gitignore @@ -165,8 +165,8 @@ cython_debug/ **/Thumbs.db # syft dirs -data/ -users/ +/data/ +/users/ .clients/ keys/ backup/ diff --git a/pyproject.toml b/pyproject.toml index 89499cd3..e15fc38d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,12 +39,27 @@ syftbox = "syftbox.main:main" pythonpath = ["."] [project.optional-dependencies] -dev = ["pytest", "pytest-xdist[psutil]", "pytest-cov", "mypy", "uv", "ruff"] +dev = [ + "pytest", + "pytest-xdist[psutil]", + "pytest-cov", + "mypy", + "uv", + "ruff", + "httpx", +] [tool.ruff] -line-length = 88 -exclude = ["data", "users", "build", "dist", ".venv"] +line-length = 120 +exclude = ["./data", "./users", "build", "dist", ".venv"] [tool.ruff.lint] -extend-select = ["I"] +select = ["E", "F", "B", "I"] +ignore = [ + "B904", # check for raise statements in exception handlers that lack a from clause + "B905", # zip() without an explicit strict= parameter +] + +[tool.ruff.lint.flake8-bugbear] +extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"] diff --git a/syftbox/server/server.py b/syftbox/server/server.py index 604512f5..2a9cc1f1 100644 --- a/syftbox/server/server.py +++ b/syftbox/server/server.py @@ -1,4 +1,5 @@ import argparse +import contextlib import json import os import random @@ -32,6 +33,9 @@ strtobin, ) +from .users.router import user_router +from .users.user import UserManager + current_dir = Path(__file__).parent @@ -42,28 +46,6 @@ FOLDERS = [DATA_FOLDER, SNAPSHOT_FOLDER] -def load_list(cls, filepath: str) -> list[Any]: - try: - with open(filepath) as f: - data = f.read() - d = json.loads(data) - ds = [] - for di in d: - ds.append(cls(**di)) - return ds - except Exception as e: - print(f"Unable to load list file: {filepath}. {e}") - return None - - -def save_list(obj: Any, filepath: str) -> None: - dicts = [] - for d in obj: - dicts.append(d.to_dict()) - with open(filepath, "w") as f: - f.write(json.dumps(dicts)) - - def load_dict(cls, filepath: str) -> list[Any]: try: with open(filepath) as f: @@ -147,6 +129,7 @@ def create_folders(folders: list[str]) -> None: os.makedirs(folder, exist_ok=True) +@contextlib.asynccontextmanager async def lifespan(app: FastAPI): # Startup print("> Starting Server") @@ -155,13 +138,19 @@ async def lifespan(app: FastAPI): print("> Loading Users") print(USERS) - yield # Run the application + state = { + "user_manager": UserManager(), + } + + yield state print("> Shutting down server") app = FastAPI(lifespan=lifespan) +app.include_router(user_router) + # Define the ASCII art ascii_art = r""" ____ __ _ ____ @@ -210,13 +199,9 @@ def get_file_list(directory="."): item_path = os.path.join(directory, item) is_dir = os.path.isdir(item_path) size = os.path.getsize(item_path) if not is_dir else "-" - mod_time = datetime.fromtimestamp(os.path.getmtime(item_path)).strftime( - "%Y-%m-%d %H:%M:%S" - ) + mod_time = datetime.fromtimestamp(os.path.getmtime(item_path)).strftime("%Y-%m-%d %H:%M:%S") - file_list.append( - {"name": item, "is_dir": is_dir, "size": size, "mod_time": mod_time} - ) + file_list.append({"name": item, "is_dir": is_dir, "size": size, "mod_time": mod_time}) return sorted(file_list, key=lambda x: (not x["is_dir"], x["name"].lower())) @@ -444,9 +429,7 @@ def main() -> None: args = parser.parse_args() uvicorn.run( - "syftbox.server.server:app" - if args.debug - else app, # Use import string in debug mode + "syftbox.server.server:app" if args.debug else app, # Use import string in debug mode host="0.0.0.0", port=args.port, log_level="debug" if args.debug else "info", diff --git a/syftbox/server/users/__init__.py b/syftbox/server/users/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/syftbox/server/users/auth.py b/syftbox/server/users/auth.py new file mode 100644 index 00000000..d06948d0 --- /dev/null +++ b/syftbox/server/users/auth.py @@ -0,0 +1,37 @@ +import secrets +from typing import Annotated, Literal + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials + +security = HTTPBasic() + +ADMIN_USERNAME = "info@openmined.org" +ADMIN_PASSWORD = "changethis" + + +def verify_admin_credentials( + credentials: Annotated[HTTPBasicCredentials, Depends(security)], +) -> Literal[True]: + """ + HTTPBasic authentication that checks if the admin credentials are correct. + + Args: + credentials (Annotated[HTTPBasicCredentials, Depends): HTTPBasic credentials + + Raises: + HTTPException: 401 Unauthorized if the credentials are incorrect + + Returns: + bool: True if the credentials are correct + """ + correct_username = secrets.compare_digest(credentials.username, ADMIN_USERNAME) + correct_password = secrets.compare_digest(credentials.password, ADMIN_PASSWORD) + + if not (correct_username and correct_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect credentials", + headers={"WWW-Authenticate": "Basic"}, + ) + return True diff --git a/syftbox/server/users/router.py b/syftbox/server/users/router.py new file mode 100644 index 00000000..7d46a6d0 --- /dev/null +++ b/syftbox/server/users/router.py @@ -0,0 +1,92 @@ +import fastapi +from fastapi import Depends, HTTPException, Request + +from .auth import verify_admin_credentials +from .user import User, UserManager + +user_router = fastapi.APIRouter( + prefix="/users", + tags=["users"], +) + + +def notify_user(user: User) -> None: + print(f"New token {user.email}: {user.token}") + + +def get_user_manager(request: Request) -> UserManager: + return request.state.user_manager + + +@user_router.post("/register_tokens") +async def register_tokens( + emails: list[str], + user_manager: UserManager = Depends(get_user_manager), + is_admin: bool = Depends(verify_admin_credentials), +) -> list[User]: + """ + Register tokens for a list of emails. + All users are created in the db with a random token, and an email is sent to each user. + + If the user already exists, the existing user is notified again with the same token. + + Args: + emails (list[str]): list of emails to register. + is_admin (bool, optional): checks if the user is an admin. + user_manager (UserManager, optional): the user manager. Defaults to Depends(get_user_manager). + + Returns: + list[User]: list of users created. + """ + users = [] + for email in emails: + user = user_manager.create_token_for_user(email) + users.append(user) + notify_user(user) + + return users + + +@user_router.post("/ban") +async def ban( + email: str, + is_admin: bool = Depends(verify_admin_credentials), + user_manager: UserManager = Depends(get_user_manager), +) -> User: + try: + user = user_manager.ban_user(email) + return user + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@user_router.post("/unban") +async def unban( + email: str, + is_admin: bool = Depends(verify_admin_credentials), + user_manager: UserManager = Depends(get_user_manager), +) -> User: + try: + user = user_manager.unban_user(email) + return user + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@user_router.post("/register") +async def register( + email: str, + token: str, + user_manager: UserManager = Depends(get_user_manager), +) -> None: + """Endpoint used by the user to register. This only works if the user has the correct token. + + Args: + email (str): user email + token (str): user token, generated by /register_tokens + """ + try: + user = user_manager.register_user(email, token) + print(f"User {user.email} registered") + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) diff --git a/syftbox/server/users/user.py b/syftbox/server/users/user.py new file mode 100644 index 00000000..24a7c42b --- /dev/null +++ b/syftbox/server/users/user.py @@ -0,0 +1,63 @@ +import secrets + +from pydantic import BaseModel + + +class User(BaseModel): + email: str + token: str + is_registered: bool = False + is_banned: bool = False + + +class UserManager: + def __init__(self): + self.users: dict[str, User] = {} + + def get_user(self, email: str) -> User | None: + return self.users.get(email) + + def create_token_for_user(self, email: str) -> User: + user = self.get_user(email) + if user is not None: + return user + + token = secrets.token_urlsafe(32) + user = User(email=email, token=token) + self.users[email] = user + return user + + def register_user(self, email: str, token: str) -> User: + user = self.get_user(email) + if user is None: + raise ValueError(f"User {email} not found") + + if user.token != token: + raise ValueError("Invalid token") + + user.is_registered = True + return user + + def ban_user(self, email: str) -> User: + user = self.get_user(email) + if user is None: + raise ValueError(f"User {email} not found") + + user.is_banned = True + return user + + def unban_user(self, email: str) -> User: + user = self.get_user(email) + if user is None: + raise ValueError(f"User {email} not found") + + user.is_banned = False + return user + + def __repr__(self) -> str: + if len(self.users) == 0: + return "UserManager()" + res = "UserManager(\n" + for email, user in self.users.items(): + res += f" {email}: {user}\n" + res += ")" diff --git a/tests/user_test.py b/tests/user_test.py new file mode 100644 index 00000000..8c4924de --- /dev/null +++ b/tests/user_test.py @@ -0,0 +1,169 @@ +import secrets + +import pytest +from fastapi import HTTPException +from fastapi.security import HTTPBasicCredentials +from fastapi.testclient import TestClient + +from syftbox.server.server import app as server_app +from syftbox.server.users.auth import ADMIN_PASSWORD, ADMIN_USERNAME, verify_admin_credentials +from syftbox.server.users.user import User, UserManager + + +@pytest.fixture(scope="function") +def client(): + with TestClient(server_app) as client: + yield client + + +@pytest.fixture(scope="function") +def user_with_token(client) -> User: + user_manager: UserManager = client.app_state["user_manager"] + return user_manager.create_token_for_user("user@openmined.org") + + +@pytest.fixture(scope="function") +def registered_user(client, user_with_token) -> User: + user_manager: UserManager = client.app_state["user_manager"] + user = user_manager.register_user(user_with_token.email, user_with_token.token) + return user + + +@pytest.fixture(scope="function") +def admin_credentials() -> HTTPBasicCredentials: + return HTTPBasicCredentials(username=ADMIN_USERNAME, password=ADMIN_PASSWORD) + + +def test_verify_admin_credentials(client, admin_credentials): + assert verify_admin_credentials(admin_credentials) + + wrong_email = HTTPBasicCredentials(username="wrong", password=ADMIN_PASSWORD) + with pytest.raises(HTTPException): + verify_admin_credentials(wrong_email) + + wrong_password = HTTPBasicCredentials(username=ADMIN_USERNAME, password="wrong") + with pytest.raises(HTTPException): + verify_admin_credentials(wrong_password) + + # Test when it is used as a dependency + result = client.post( + "/users/register_tokens", + json=["test_user@openmined.org"], + auth=(admin_credentials.username, admin_credentials.password), + ) + assert result.status_code == 200, result.json() + print(result.json()) + + # wrong admin credentials fails + result = client.post( + "/users/register_tokens", + json=["test_user@openmined.org"], + auth=(wrong_password.username, wrong_password.password), + ) + assert result.status_code == 401, result.json() + print(result.json()) + + # no credentials fails + result = client.post( + "/users/register_tokens", + json=["test_user@openmined.org"], + ) + assert result.status_code == 401, result.json() + print(result.json()) + + +def test_register_tokens(client, admin_credentials): + user_manager: UserManager = client.app_state["user_manager"] + + num_users = 3 + emails = [f"user_{i}@openmined.org" for i in range(num_users)] + + result = client.post( + "/users/register_tokens", + json=emails, + auth=(admin_credentials.username, admin_credentials.password), + ) + result.raise_for_status() + content = result.json() + assert len(content) == num_users + + # all users exist + for email in emails: + user = user_manager.get_user(email) + assert user is not None + assert user.email == email + assert not user.is_banned and not user.is_registered # not banned, not registered + + +def test_ban_non_existing_user(client, admin_credentials): + result = client.post( + "/users/ban", + params={"email": "doesnt_exist@openmined.org"}, + auth=(admin_credentials.username, admin_credentials.password), + ) + assert result.status_code == 404, result.json() + + +def test_ban(client, admin_credentials, registered_user): + user_manager: UserManager = client.app_state["user_manager"] + user = user_manager.get_user(registered_user.email) + assert user.is_banned is False + + # require admin credentials + result = client.post( + "/users/ban", + params={"email": registered_user.email}, + ) + assert result.status_code == 401, result.json() + + result = client.post( + "/users/ban", + params={"email": registered_user.email}, + auth=(admin_credentials.username, admin_credentials.password), + ) + result.raise_for_status() + assert user_manager.get_user(registered_user.email).is_banned + + +def test_unban(client, admin_credentials, registered_user): + user_manager: UserManager = client.app_state["user_manager"] + user_manager.ban_user(registered_user.email) + user = user_manager.get_user(registered_user.email) + assert user.is_banned + + # require admin credentials + result = client.post( + "/users/unban", + params={"email": registered_user.email}, + ) + assert result.status_code == 401, result.json() + + result = client.post( + "/users/unban", + params={"email": registered_user.email}, + auth=(admin_credentials.username, admin_credentials.password), + ) + result.raise_for_status() + assert not user_manager.get_user(registered_user.email).is_banned + + +def test_register_user(client, user_with_token): + user_manager: UserManager = client.app_state["user_manager"] + assert user_manager.get_user(user_with_token.email).is_registered is False + + # wrong token + wrong_token = secrets.token_urlsafe(32) + result = client.post( + "/users/register", + params={"email": user_with_token.email, "token": wrong_token}, + ) + assert result.status_code == 404, result.json() + assert user_manager.get_user(user_with_token.email).is_registered is False + + # correct token + result = client.post( + "/users/register", + params={"email": user_with_token.email, "token": user_with_token.token}, + ) + result.raise_for_status() + assert user_manager.get_user(user_with_token.email).is_registered