Skip to content

Commit

Permalink
try csrf-protection
Browse files Browse the repository at this point in the history
  • Loading branch information
jrycw committed Mar 2, 2024
1 parent 629770c commit aec53ff
Show file tree
Hide file tree
Showing 19 changed files with 184 additions and 89 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ jobs:
run: |
. .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: Set env
run: echo "secret_csrf=$(openssl rand -hex 32)" >> $GITHUB_ENV

- name: Test with pytest
run: pytest
5 changes: 3 additions & 2 deletions app/_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ async def _lifespan(app: FastAPI, registry: svcs.Registry, prefill: bool = False
# EdgeDB client
db_client = edgedb.create_async_client()

async def setup_db_client():
async def create_db_client():
"""only 1 db_client"""
yield db_client

async def ping_db_callable(_db_client):
return await _db_client.query("select 1;")

registry.register_factory(
AsyncIOClient,
setup_db_client,
create_db_client,
ping=ping_db_callable,
)

Expand Down
7 changes: 3 additions & 4 deletions app/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from fastapi import APIRouter, Request
from fastapi import APIRouter

router = APIRouter(include_in_schema=False)


@router.get("/")
async def home(request: Request):
client_host = request.client.host
return {"message": "Hello World from FastAPI", "client_host": client_host}
async def home():
return {"message": "Hello World from FastAPI"}
1 change: 1 addition & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Settings(BaseSettings):
backendreload: bool = False

tz: str = "UTC"
secret_csrf: str

model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

Expand Down
18 changes: 9 additions & 9 deletions app/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ async def get_events(
services: svcs.fastapi.DepContainer,
name: Annotated[str | None, Query(max_length=50)] = None,
):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
if name is None:
return await get_events_qry.get_events(client)
return await get_events_qry.get_events(db_client)
else:
if event := await get_event_by_name_qry.get_event_by_name(client, name=name):
if event := await get_event_by_name_qry.get_event_by_name(db_client, name=name):
return event

raise HTTPException(
Expand All @@ -59,10 +59,10 @@ async def post_event(
services: svcs.fastapi.DepContainer,
event: EventCreate,
):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
try:
created_event = await create_event_qry.create_event(
client, **event.model_dump()
db_client, **event.model_dump()
)
except edgedb.errors.InvalidArgumentError:
raise HTTPException(
Expand Down Expand Up @@ -96,10 +96,10 @@ async def put_event(
services: svcs.fastapi.DepContainer,
event: EventUpdate,
):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
try:
updated_event = await update_event_qry.update_event(
client, **event.model_dump()
db_client, **event.model_dump()
)
except edgedb.errors.InvalidArgumentError:
raise HTTPException(
Expand Down Expand Up @@ -137,8 +137,8 @@ async def delete_event(
services: svcs.fastapi.DepContainer,
name: Annotated[str, Query(max_length=50)],
):
client = await services.aget(AsyncIOClient)
if deleted_event := await delete_event_qry.delete_event(client, name=name):
db_client = await services.aget(AsyncIOClient)
if deleted_event := await delete_event_qry.delete_event(db_client, name=name):
return deleted_event

raise HTTPException(
Expand Down
11 changes: 8 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from starlette_csrf import CSRFMiddleware

sys.path.append(os.getcwd())

from app import common, events, health, users
from app.config import settings
from app.lifespan import lifespan
Expand All @@ -16,14 +18,17 @@ def make_app(lifespan):

app.add_middleware(
CORSMiddleware,
allow_origins=[
f"{settings.frontendschema}://{settings.frontendhost}:{settings.frontendport}"
],
# allow_origins=[
# f"{settings.frontendschema}://{settings.frontendhost}:{settings.frontendport}"
# ],
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

app.add_middleware(CSRFMiddleware, secret=settings.secret_csrf)

app.include_router(users.router)
app.include_router(events.router)
app.include_router(health.router)
Expand Down
22 changes: 11 additions & 11 deletions app/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ async def search_users_ilike(
services: svcs.fastapi.DepContainer,
name: Annotated[str | None, Query(max_length=50)] = None,
):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
return await search_users_by_name_ilike_qry.search_users_by_name_ilike(
client, name=name
db_client, name=name
)


Expand All @@ -54,11 +54,11 @@ async def get_users(
services: svcs.fastapi.DepContainer,
name: Annotated[str | None, Query(max_length=50)] = None,
):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
if name is None:
return await get_users_qry.get_users(client)
return await get_users_qry.get_users(db_client)
else:
if user := await get_user_by_name_qry.get_user_by_name(client, name=name):
if user := await get_user_by_name_qry.get_user_by_name(db_client, name=name):
return user

raise HTTPException(
Expand All @@ -79,9 +79,9 @@ async def get_users(
tags=["users"],
)
async def post_user(services: svcs.fastapi.DepContainer, user: UserCreate):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
try:
created_user = await create_user_qry.create_user(client, **user.model_dump())
created_user = await create_user_qry.create_user(db_client, **user.model_dump())
except edgedb.errors.ConstraintViolationError:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
Expand All @@ -105,9 +105,9 @@ async def put_user(
services: svcs.fastapi.DepContainer,
user: UserUpdate,
):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
try:
updated_user = await update_user_qry.update_user(client, **user.model_dump())
updated_user = await update_user_qry.update_user(db_client, **user.model_dump())
except edgedb.errors.ConstraintViolationError:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
Expand Down Expand Up @@ -136,10 +136,10 @@ async def put_user(
async def delete_user(
services: svcs.fastapi.DepContainer, name: Annotated[str, Query(max_length=50)]
):
client = await services.aget(AsyncIOClient)
db_client = await services.aget(AsyncIOClient)
try:
deleted_user = await delete_user_qry.delete_user(
client,
db_client,
name=name,
)
except edgedb.errors.ConstraintViolationError:
Expand Down
9 changes: 8 additions & 1 deletion fastui_app/_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from fastapi import FastAPI
from httpx import AsyncClient

from fastui_app.config import settings
from .clients import PostPutDeleteAsyncClient
from .config import settings
from .utils import create_post_put_delete_web_client


async def _lifespan(app: FastAPI, registry: svcs.Registry):
Expand All @@ -13,13 +15,18 @@ async def _lifespan(app: FastAPI, registry: svcs.Registry):
client = AsyncClient(base_url=base_url)

async def setup_httpx_client():
"""only 1 web client for GET"""
yield client

registry.register_factory(
AsyncClient,
setup_httpx_client,
)

registry.register_factory(
PostPutDeleteAsyncClient, create_post_put_delete_web_client
)

yield

await registry.aclose()
5 changes: 5 additions & 0 deletions fastui_app/clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from httpx import AsyncClient


class PostPutDeleteAsyncClient(AsyncClient):
pass
1 change: 1 addition & 0 deletions fastui_app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Settings(BaseSettings):
backendreload: bool = False

tz: str = "UTC"
secret_csrf: str

model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

Expand Down
13 changes: 7 additions & 6 deletions fastui_app/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastui.forms import fastui_form
from httpx import AsyncClient

from .clients import PostPutDeleteAsyncClient
from .forms import EventCreationForm, EventUpdateForm
from .shared import demo_page
from .utils import _form_event_repr, _raise_for_status
Expand All @@ -26,11 +27,11 @@ async def event_createview(
services: svcs.fastapi.DepContainer,
form: Annotated[EventCreationForm, fastui_form(EventCreationForm)],
):
client = await services.aget(AsyncClient)
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
form_dict = form.model_dump()
if s := form_dict["schedule"]:
form_dict.update(schedule=s.isoformat())
resp = await client.post("/events", json=form_dict)
resp = await client.post("/events", json=form_dict, **extra_headers)
if resp.status_code != HTTPStatus.CREATED:
resp_json = resp.json()
return [
Expand Down Expand Up @@ -138,11 +139,11 @@ async def event_updateview(
form: Annotated[EventUpdateForm, fastui_form(EventUpdateForm)],
name: str,
) -> list[AnyComponent]:
client = await services.aget(AsyncClient)
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
form_dict = {"name": name} | form.model_dump()
if s := form_dict["schedule"]:
form_dict.update(schedule=s.isoformat())
resp = await client.put("/events", json=form_dict)
resp = await client.put("/events", json=form_dict, **extra_headers)
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand All @@ -163,8 +164,8 @@ async def event_deleteview(
services: svcs.fastapi.DepContainer,
name: str,
) -> list[AnyComponent]:
client = await services.aget(AsyncClient)
resp = await client.delete("/events", params={"name": name})
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
resp = await client.delete("/events", params={"name": name}, **extra_headers)
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand Down
16 changes: 10 additions & 6 deletions fastui_app/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,15 @@
user_prefer_zoneinfo = ZoneInfo(settings.tz)


class DisplayAuditable:
class _DisplayAuditable:
@field_validator("created_at")
@classmethod
def validate_created_at(cls, v: datetime.datetime) -> datetime.datetime:
"""This function decides how to render `Created_at`"""
return v.astimezone(user_prefer_zoneinfo).replace(microsecond=0)


class UserRepr(UserFull, DisplayAuditable):
pass


class EventRepr(EventFull, DisplayAuditable):
class _DisplaySchedule:
@field_validator("schedule")
@classmethod
def validate_schedule(cls, v: str | None) -> str | None:
Expand All @@ -54,6 +50,14 @@ def validate_schedule(cls, v: str | None) -> str | None:
return v


class UserRepr(UserFull, _DisplayAuditable):
pass


class EventRepr(EventFull, _DisplaySchedule, _DisplayAuditable):
pass


################################
# User
################################
Expand Down
1 change: 1 addition & 0 deletions fastui_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from starlette.middleware.cors import CORSMiddleware

sys.path.append(os.getcwd())

from fastui_app import common, events, users
from fastui_app.config import settings
from fastui_app.lifespan import lifespan
Expand Down
14 changes: 8 additions & 6 deletions fastui_app/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from fastui.forms import SelectSearchResponse, fastui_form
from httpx import AsyncClient

from .clients import PostPutDeleteAsyncClient
from .forms import UserCreationForm, UserUpdateForm
from .shared import demo_page
from .utils import _form_user_repr, _raise_for_status

router = APIRouter(include_in_schema=False)


# TODO: Not ready
@router.get("/api/users/search", response_model=SelectSearchResponse)
async def user_ilike_searchview(
services: svcs.fastapi.DepContainer, name: str | None = None
Expand All @@ -37,8 +39,8 @@ async def user_createview(
services: svcs.fastapi.DepContainer,
form: Annotated[UserCreationForm, fastui_form(UserCreationForm)],
):
client = await services.aget(AsyncClient)
resp = await client.post("/users", json=form.model_dump())
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
resp = await client.post("/users", json=form.model_dump(), **extra_headers)

# raised, but how to do a full page reload?
# resp_json = _raise_for_status(resp, HTTPStatus.CREATED)
Expand Down Expand Up @@ -148,9 +150,9 @@ async def user_updateview(
form: Annotated[UserUpdateForm, fastui_form(UserUpdateForm)],
name: str,
) -> list[AnyComponent]:
client = await services.aget(AsyncClient)
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
form_dict = {"name": name} | form.model_dump()
resp = await client.put("/users", json=form_dict)
resp = await client.put("/users", json=form_dict, **extra_headers)
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand All @@ -169,8 +171,8 @@ async def user_deleteview(
services: svcs.fastapi.DepContainer,
name: str,
) -> list[AnyComponent]:
client = await services.aget(AsyncClient)
resp = await client.delete("/users", params={"name": name})
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
resp = await client.delete("/users", params={"name": name}, **extra_headers)
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand Down
Loading

0 comments on commit aec53ff

Please sign in to comment.