Skip to content

Commit

Permalink
add token, register, ban, unban
Browse files Browse the repository at this point in the history
  • Loading branch information
eelcovdw committed Oct 8, 2024
1 parent 674e541 commit 781c973
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 38 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ cython_debug/
**/Thumbs.db

# syft dirs
data/
users/
/data/
/users/
.clients/
keys/
backup/
Expand Down
23 changes: 19 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,27 @@ syftbox = "syftbox.main:main"
pythonpath = ["."]

[project.optional-dependencies]
dev = ["pytest", "pytest-xdist[psutil]", "pytest-cov", "mypy", "uv", "ruff"]
dev = [
"pytest",
"pytest-xdist[psutil]",
"pytest-cov",
"mypy",
"uv",
"ruff",
"httpx",
]


[tool.ruff]
line-length = 88
exclude = ["data", "users", "build", "dist", ".venv"]
line-length = 120
exclude = ["./data", "./users", "build", "dist", ".venv"]

[tool.ruff.lint]
extend-select = ["I"]
select = ["E", "F", "B", "I"]
ignore = [
"B904", # check for raise statements in exception handlers that lack a from clause
"B905", # zip() without an explicit strict= parameter
]

[tool.ruff.lint.flake8-bugbear]
extend-immutable-calls = ["fastapi.Depends", "fastapi.Query"]
47 changes: 15 additions & 32 deletions syftbox/server/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import contextlib
import json
import os
import random
Expand Down Expand Up @@ -32,6 +33,9 @@
strtobin,
)

from .users.router import user_router
from .users.user import UserManager

current_dir = Path(__file__).parent


Expand All @@ -42,28 +46,6 @@
FOLDERS = [DATA_FOLDER, SNAPSHOT_FOLDER]


def load_list(cls, filepath: str) -> list[Any]:
try:
with open(filepath) as f:
data = f.read()
d = json.loads(data)
ds = []
for di in d:
ds.append(cls(**di))
return ds
except Exception as e:
print(f"Unable to load list file: {filepath}. {e}")
return None


def save_list(obj: Any, filepath: str) -> None:
dicts = []
for d in obj:
dicts.append(d.to_dict())
with open(filepath, "w") as f:
f.write(json.dumps(dicts))


def load_dict(cls, filepath: str) -> list[Any]:
try:
with open(filepath) as f:
Expand Down Expand Up @@ -147,6 +129,7 @@ def create_folders(folders: list[str]) -> None:
os.makedirs(folder, exist_ok=True)


@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
print("> Starting Server")
Expand All @@ -155,13 +138,19 @@ async def lifespan(app: FastAPI):
print("> Loading Users")
print(USERS)

yield # Run the application
state = {
"user_manager": UserManager(),
}

yield state

print("> Shutting down server")


app = FastAPI(lifespan=lifespan)

app.include_router(user_router)

# Define the ASCII art
ascii_art = r"""
____ __ _ ____
Expand Down Expand Up @@ -210,13 +199,9 @@ def get_file_list(directory="."):
item_path = os.path.join(directory, item)
is_dir = os.path.isdir(item_path)
size = os.path.getsize(item_path) if not is_dir else "-"
mod_time = datetime.fromtimestamp(os.path.getmtime(item_path)).strftime(
"%Y-%m-%d %H:%M:%S"
)
mod_time = datetime.fromtimestamp(os.path.getmtime(item_path)).strftime("%Y-%m-%d %H:%M:%S")

file_list.append(
{"name": item, "is_dir": is_dir, "size": size, "mod_time": mod_time}
)
file_list.append({"name": item, "is_dir": is_dir, "size": size, "mod_time": mod_time})

return sorted(file_list, key=lambda x: (not x["is_dir"], x["name"].lower()))

Expand Down Expand Up @@ -444,9 +429,7 @@ def main() -> None:
args = parser.parse_args()

uvicorn.run(
"syftbox.server.server:app"
if args.debug
else app, # Use import string in debug mode
"syftbox.server.server:app" if args.debug else app, # Use import string in debug mode
host="0.0.0.0",
port=args.port,
log_level="debug" if args.debug else "info",
Expand Down
Empty file.
37 changes: 37 additions & 0 deletions syftbox/server/users/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import secrets
from typing import Annotated, Literal

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials

security = HTTPBasic()

ADMIN_USERNAME = "info@openmined.org"
ADMIN_PASSWORD = "changethis"


def verify_admin_credentials(
credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> Literal[True]:
"""
HTTPBasic authentication that checks if the admin credentials are correct.
Args:
credentials (Annotated[HTTPBasicCredentials, Depends): HTTPBasic credentials
Raises:
HTTPException: 401 Unauthorized if the credentials are incorrect
Returns:
bool: True if the credentials are correct
"""
correct_username = secrets.compare_digest(credentials.username, ADMIN_USERNAME)
correct_password = secrets.compare_digest(credentials.password, ADMIN_PASSWORD)

if not (correct_username and correct_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect credentials",
headers={"WWW-Authenticate": "Basic"},
)
return True
92 changes: 92 additions & 0 deletions syftbox/server/users/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import fastapi
from fastapi import Depends, HTTPException, Request

from .auth import verify_admin_credentials
from .user import User, UserManager

user_router = fastapi.APIRouter(
prefix="/users",
tags=["users"],
)


def notify_user(user: User) -> None:
print(f"New token {user.email}: {user.token}")


def get_user_manager(request: Request) -> UserManager:
return request.state.user_manager


@user_router.post("/register_tokens")
async def register_tokens(
emails: list[str],
user_manager: UserManager = Depends(get_user_manager),
is_admin: bool = Depends(verify_admin_credentials),
) -> list[User]:
"""
Register tokens for a list of emails.
All users are created in the db with a random token, and an email is sent to each user.
If the user already exists, the existing user is notified again with the same token.
Args:
emails (list[str]): list of emails to register.
is_admin (bool, optional): checks if the user is an admin.
user_manager (UserManager, optional): the user manager. Defaults to Depends(get_user_manager).
Returns:
list[User]: list of users created.
"""
users = []
for email in emails:
user = user_manager.create_token_for_user(email)
users.append(user)
notify_user(user)

return users


@user_router.post("/ban")
async def ban(
email: str,
is_admin: bool = Depends(verify_admin_credentials),
user_manager: UserManager = Depends(get_user_manager),
) -> User:
try:
user = user_manager.ban_user(email)
return user
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))


@user_router.post("/unban")
async def unban(
email: str,
is_admin: bool = Depends(verify_admin_credentials),
user_manager: UserManager = Depends(get_user_manager),
) -> User:
try:
user = user_manager.unban_user(email)
return user
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))


@user_router.post("/register")
async def register(
email: str,
token: str,
user_manager: UserManager = Depends(get_user_manager),
) -> None:
"""Endpoint used by the user to register. This only works if the user has the correct token.
Args:
email (str): user email
token (str): user token, generated by /register_tokens
"""
try:
user = user_manager.register_user(email, token)
print(f"User {user.email} registered")
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
63 changes: 63 additions & 0 deletions syftbox/server/users/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import secrets

from pydantic import BaseModel


class User(BaseModel):
email: str
token: str
is_registered: bool = False
is_banned: bool = False


class UserManager:
def __init__(self):
self.users: dict[str, User] = {}

def get_user(self, email: str) -> User | None:
return self.users.get(email)

def create_token_for_user(self, email: str) -> User:
user = self.get_user(email)
if user is not None:
return user

token = secrets.token_urlsafe(32)
user = User(email=email, token=token)
self.users[email] = user
return user

def register_user(self, email: str, token: str) -> User:
user = self.get_user(email)
if user is None:
raise ValueError(f"User {email} not found")

if user.token != token:
raise ValueError("Invalid token")

user.is_registered = True
return user

def ban_user(self, email: str) -> User:
user = self.get_user(email)
if user is None:
raise ValueError(f"User {email} not found")

user.is_banned = True
return user

def unban_user(self, email: str) -> User:
user = self.get_user(email)
if user is None:
raise ValueError(f"User {email} not found")

user.is_banned = False
return user

def __repr__(self) -> str:
if len(self.users) == 0:
return "UserManager()"
res = "UserManager(\n"
for email, user in self.users.items():
res += f" {email}: {user}\n"
res += ")"
Loading

0 comments on commit 781c973

Please sign in to comment.