Skip to content

Commit

Permalink
remove backend csrf-protection
Browse files Browse the repository at this point in the history
  • Loading branch information
jrycw committed Mar 2, 2024
1 parent aec53ff commit bfbd293
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 129 deletions.
3 changes: 0 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from starlette_csrf import CSRFMiddleware

sys.path.append(os.getcwd())

Expand All @@ -27,8 +26,6 @@ def make_app(lifespan):
allow_headers=["*"],
)

app.add_middleware(CSRFMiddleware, secret=settings.secret_csrf)

app.include_router(users.router)
app.include_router(events.router)
app.include_router(health.router)
Expand Down
51 changes: 38 additions & 13 deletions fastui_app/_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,55 @@
import svcs
from fastapi import FastAPI
from httpx import AsyncClient

from .clients import PostPutDeleteAsyncClient
from .clients import (
BackendAsyncClient,
FrontendGetAsyncClient,
FrontendPostPutDeleteAsyncClient,
)
from .config import settings
from .utils import create_post_put_delete_web_client


async def _lifespan(app: FastAPI, registry: svcs.Registry):
# Web client(connect to backend)
base_url = (
f"{settings.backendschema}://{settings.backendhost}:{settings.backendport}"
backend_client = BackendAsyncClient(
base_url=f"{settings.backendschema}://{settings.backendhost}:{settings.backendport}"
)
client = AsyncClient(base_url=base_url)

async def setup_httpx_client():
"""only 1 web client for GET"""
yield client
async def creat_backend_client():
"""1 backend web client"""
yield backend_client

registry.register_factory(
AsyncClient,
setup_httpx_client,
BackendAsyncClient,
creat_backend_client,
)

front_get_client = BackendAsyncClient(
base_url=f"{settings.frontendschema}://{settings.frontendhost}:{settings.frontendport}"
)

async def create_frontend_get_client():
"""1 frontent web GET client"""
yield front_get_client

registry.register_factory(
FrontendGetAsyncClient,
create_frontend_get_client,
)

async def create_frontend_post_put_delete_client():
"""For every post/put/delete, we request 1 specialized web client"""
base_url = f"{settings.frontendschema}://{settings.frontendhost}:{settings.frontendport}"
async with FrontendPostPutDeleteAsyncClient(base_url=base_url) as client:
csrftoken = (await client.get("/")).cookies.get("csrftoken")
# extra_headers = (
# {"headers": {"x-csrftoken": csrftoken}} if csrftoken is not None else {}
# )
csrftoken_dict = {"x-csrftoken": csrftoken} if csrftoken is not None else {}
yield client, csrftoken_dict

registry.register_factory(
PostPutDeleteAsyncClient, create_post_put_delete_web_client
FrontendPostPutDeleteAsyncClient,
create_frontend_post_put_delete_client,
)

yield
Expand Down
10 changes: 9 additions & 1 deletion fastui_app/clients.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from httpx import AsyncClient


class PostPutDeleteAsyncClient(AsyncClient):
class FrontendGetAsyncClient(AsyncClient):
pass


class FrontendPostPutDeleteAsyncClient(AsyncClient):
pass


class BackendAsyncClient(AsyncClient):
pass
23 changes: 13 additions & 10 deletions fastui_app/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from fastui.components.display import DisplayLookup, DisplayMode
from fastui.events import BackEvent, GoToEvent, PageEvent
from fastui.forms import fastui_form
from httpx import AsyncClient

from .clients import PostPutDeleteAsyncClient
from .clients import ( # noqa: F401
BackendAsyncClient,
FrontendGetAsyncClient,
FrontendPostPutDeleteAsyncClient,
)
from .forms import EventCreationForm, EventUpdateForm
from .shared import demo_page
from .utils import _form_event_repr, _raise_for_status
Expand All @@ -27,11 +30,11 @@ async def event_createview(
services: svcs.fastapi.DepContainer,
form: Annotated[EventCreationForm, fastui_form(EventCreationForm)],
):
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
client = await services.aget(BackendAsyncClient)
form_dict = form.model_dump()
if s := form_dict["schedule"]:
form_dict.update(schedule=s.isoformat())
resp = await client.post("/events", json=form_dict, **extra_headers)
resp = await client.post("/events", json=form_dict)
if resp.status_code != HTTPStatus.CREATED:
resp_json = resp.json()
return [
Expand All @@ -50,7 +53,7 @@ async def event_createview(
async def event_detailview(
services: svcs.fastapi.DepContainer, name: str
) -> list[AnyComponent]:
client = await services.aget(AsyncClient)
client = await services.aget(BackendAsyncClient)
resp = await client.get("/events", params={"name": name})
resp_json = _raise_for_status(resp) # try using prebuilt_html
event = _form_event_repr(resp_json)
Expand Down Expand Up @@ -139,11 +142,11 @@ async def event_updateview(
form: Annotated[EventUpdateForm, fastui_form(EventUpdateForm)],
name: str,
) -> list[AnyComponent]:
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
client = await services.aget(BackendAsyncClient)
form_dict = {"name": name} | form.model_dump()
if s := form_dict["schedule"]:
form_dict.update(schedule=s.isoformat())
resp = await client.put("/events", json=form_dict, **extra_headers)
resp = await client.put("/events", json=form_dict)
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand All @@ -164,8 +167,8 @@ async def event_deleteview(
services: svcs.fastapi.DepContainer,
name: str,
) -> list[AnyComponent]:
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
resp = await client.delete("/events", params={"name": name}, **extra_headers)
client = await services.aget(BackendAsyncClient)
resp = await client.delete("/events", params={"name": name})
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand All @@ -181,7 +184,7 @@ async def event_deleteview(
async def event_listview(
services: svcs.fastapi.DepContainer,
) -> list[AnyComponent]:
client = await services.aget(AsyncClient)
client = await services.aget(BackendAsyncClient)
resp = await client.get("/events")
resp_json_list = _raise_for_status(resp, HTTPStatus.OK)
events = [_form_event_repr(resp_json) for resp_json in resp_json_list]
Expand Down
6 changes: 5 additions & 1 deletion fastui_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

# from starlette_csrf import CSRFMiddleware

sys.path.append(os.getcwd())

from fastui_app import common, events, users
from fastui_app import common, events, users # noqa: F401
from fastui_app.config import settings
from fastui_app.lifespan import lifespan

Expand All @@ -23,6 +25,8 @@ def make_app(lifespan):
allow_headers=["*"],
)

# app.add_middleware(CSRFMiddleware, secret=settings.secret_csrf)

# order matters
app.include_router(users.router)
app.include_router(events.router)
Expand Down
30 changes: 17 additions & 13 deletions fastui_app/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from fastui.components.display import DisplayLookup, DisplayMode
from fastui.events import BackEvent, GoToEvent, PageEvent
from fastui.forms import SelectSearchResponse, fastui_form
from httpx import AsyncClient

from .clients import PostPutDeleteAsyncClient
from .clients import ( # noqa: F401
BackendAsyncClient,
FrontendGetAsyncClient,
FrontendPostPutDeleteAsyncClient,
)
from .forms import UserCreationForm, UserUpdateForm
from .shared import demo_page
from .utils import _form_user_repr, _raise_for_status
Expand All @@ -23,7 +26,7 @@
async def user_ilike_searchview(
services: svcs.fastapi.DepContainer, name: str | None = None
):
client = await services.aget(AsyncClient)
client = await services.aget(BackendAsyncClient)
resp = await client.get("/users/search", params={"name": name})
usernames = resp.json()
options = [{"label": name, "value": name} for name in usernames]
Expand All @@ -39,9 +42,8 @@ async def user_createview(
services: svcs.fastapi.DepContainer,
form: Annotated[UserCreationForm, fastui_form(UserCreationForm)],
):
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
resp = await client.post("/users", json=form.model_dump(), **extra_headers)

client = await services.aget(BackendAsyncClient)
resp = await client.post("/users", json=form.model_dump())
# raised, but how to do a full page reload?
# resp_json = _raise_for_status(resp, HTTPStatus.CREATED)

Expand All @@ -62,7 +64,7 @@ async def user_createview(
async def user_detailview(
services: svcs.fastapi.DepContainer, name: str
) -> list[AnyComponent]:
client = await services.aget(AsyncClient)
client = await services.aget(BackendAsyncClient)
resp = await client.get("/users", params={"name": name})
resp_json = _raise_for_status(resp) # try using prebuilt_html
user = _form_user_repr(resp_json)
Expand All @@ -88,7 +90,9 @@ async def user_detailview(
submit_url=f"/api/users/{name}/update/",
loading=[c.Spinner(text="Updating...")],
footer=[],
submit_trigger=PageEvent(name="modal-form-update-user-submit"),
submit_trigger=PageEvent(
name="modal-form-update-user-submit",
),
),
],
footer=[
Expand Down Expand Up @@ -150,9 +154,9 @@ async def user_updateview(
form: Annotated[UserUpdateForm, fastui_form(UserUpdateForm)],
name: str,
) -> list[AnyComponent]:
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
client = await services.aget(BackendAsyncClient)
form_dict = {"name": name} | form.model_dump()
resp = await client.put("/users", json=form_dict, **extra_headers)
resp = await client.put("/users", json=form_dict)
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand All @@ -171,8 +175,8 @@ async def user_deleteview(
services: svcs.fastapi.DepContainer,
name: str,
) -> list[AnyComponent]:
client, extra_headers = await services.aget(PostPutDeleteAsyncClient)
resp = await client.delete("/users", params={"name": name}, **extra_headers)
client = await services.aget(BackendAsyncClient)
resp = await client.delete("/users", params={"name": name})
if resp.status_code != HTTPStatus.OK:
resp_json = resp.json()
return [
Expand All @@ -192,7 +196,7 @@ async def user_listview(
Show a table of four users, `/api` is the endpoint the frontend will connect to
when a user visits `/` to fetch components to render.
"""
client = await services.aget(AsyncClient)
client = await services.aget(BackendAsyncClient)
resp = await client.get("/users")
resp_json_list = _raise_for_status(resp, HTTPStatus.OK)
users = [_form_user_repr(resp_json) for resp_json in resp_json_list]
Expand Down
22 changes: 0 additions & 22 deletions fastui_app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,10 @@

from fastapi import HTTPException
from fastapi.responses import Response
from httpx import AsyncClient

from .config import settings
from .forms import EventRepr, UserRepr


async def create_get_web_client():
base_url = (
f"{settings.backendschema}://{settings.backendhost}:{settings.backendport}"
)
async with AsyncClient(base_url=base_url) as client:
yield client


async def create_post_put_delete_web_client():
base_url = (
f"{settings.backendschema}://{settings.backendhost}:{settings.backendport}"
)
async with AsyncClient(base_url=base_url) as client:
csrftoken = (await client.get("/")).cookies.get("csrftoken")
extra_headers = (
{"headers": {"x-csrftoken": csrftoken}} if csrftoken is not None else {}
)
yield client, extra_headers


def _raise_for_status(response: Response, status_code: HTTPStatus = HTTPStatus.OK):
resp_json = response.json()
if response.status_code != status_code:
Expand Down
10 changes: 0 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,6 @@ def test_client(test_app):
yield client


@pytest.fixture(scope="function")
def csrftoken(test_client):
yield (test_client.get("/")).cookies.get("csrftoken")


@pytest.fixture(scope="function")
def extra_headers(csrftoken):
yield {"headers": {"x-csrftoken": csrftoken}} if csrftoken is not None else {}


@pytest.fixture
def test_db_client():
yield Mock(spec_set=AsyncIOClient)
Expand Down
Loading

0 comments on commit bfbd293

Please sign in to comment.