Skip to content

Commit

Permalink
Add testing for CSRF fail
Browse files Browse the repository at this point in the history
  • Loading branch information
mxsasha committed Nov 28, 2024
1 parent 583bd6a commit 52bea40
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 13 deletions.
7 changes: 5 additions & 2 deletions irrd/server/http/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

def set_middleware(app):
testing = os.environ.get("TESTING", False)
if testing:
csrf_disabled = testing and not getattr(app, "force_csrf_in_testing", False)
if csrf_disabled:
logger.info("Running in testing mode, disabling CSRF.")
app.user_middleware = [
# Use asgi-log to work around https://github.com/encode/uvicorn/issues/1384
Expand All @@ -157,7 +158,9 @@ def set_middleware(app):
Middleware(MemoryTrimMiddleware),
Middleware(SessionMiddleware, secret_key=secret_key_derive("web.session_middleware")),
Middleware(
CSRFProtectMiddleware, csrf_secret=secret_key_derive("web.csrf_middleware"), enabled=not testing
CSRFProtectMiddleware,
csrf_secret=secret_key_derive("web.csrf_middleware"),
enabled=not csrf_disabled,
),
auth_middleware,
]
Expand Down
13 changes: 10 additions & 3 deletions irrd/webui/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from asgiref.sync import sync_to_async
from starlette.requests import Request
from starlette.responses import Response
from starlette_wtf import csrf_protect, csrf_token
from starlette_wtf import CSRFError, csrf_protect, csrf_token

from irrd import META_KEY_HTTP_CLIENT_IP
from irrd.conf import get_setting
Expand Down Expand Up @@ -104,13 +104,20 @@ async def rpsl_detail(request: Request, user_mfa_incomplete: bool, session_provi


def optional_csrf_protect(func):
"""
The RPSL update endpoint is special re CSRF: it may be called from
a browser, with typically a valid CSRF token, or from an API call,
without CSRF. Therefore, this decorator tries to validate CSRF,
and if not, tells the endpoint, which will then ignore user session
info and only look at the post data.
"""

@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
decorated_func = csrf_protect(func)
return await decorated_func(*args, csrf_protected=True, **kwargs)
except Exception as e:
print(f"Exception captured: {e}")
except CSRFError:
return await func(*args, csrf_protected=False, **kwargs)

return wrapper
Expand Down
68 changes: 60 additions & 8 deletions irrd/webui/tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from unittest.mock import create_autospec

import pytest
from starlette.testclient import TestClient

from irrd.updates.handler import ChangeSubmissionHandler
from irrd.utils.rpsl_samples import SAMPLE_MNTNER
from irrd.webui import datetime_format

from ...rpsl.rpsl_objects import rpsl_object_from_text
from ...server.http.app import app
from ...storage.database_handler import DatabaseHandler
from ...storage.models import JournalEntryOrigin
from ...updates.parser_state import UpdateRequestType
Expand Down Expand Up @@ -141,7 +143,11 @@ class TestRpslUpdateNoInitial(WebRequestTest):
requires_mfa = False

def test_valid_mntner_logged_in_mfa_complete_no_user_management(
self, irrd_db_session_with_user, test_client, mock_change_submission_handler
self,
irrd_db_session_with_user,
test_client,
mock_change_submission_handler,
tmp_gpg_dir,
):
session_provider, user = irrd_db_session_with_user
self._login(test_client, user)
Expand All @@ -154,7 +160,11 @@ def test_valid_mntner_logged_in_mfa_complete_no_user_management(
assert "(you can not update this mntner itself)" in response.text

def test_valid_mntner_logged_in_mfa_complete_user_management(
self, irrd_db_session_with_user, test_client, mock_change_submission_handler
self,
irrd_db_session_with_user,
test_client,
mock_change_submission_handler,
tmp_gpg_dir,
):
session_provider, user = irrd_db_session_with_user
self._login(test_client, user)
Expand All @@ -173,8 +183,37 @@ def test_valid_mntner_logged_in_mfa_complete_user_management(
assert mock_handler_kwargs["object_texts_blob"] == SAMPLE_MNTNER
assert mock_handler_kwargs["internal_authenticated_user"].pk == user.pk

def test_valid_mntner_logged_in_mfa_complete_user_management_no_csrf(
self,
irrd_db_session_with_user,
test_client,
mock_change_submission_handler,
tmp_gpg_dir,
):
# print(test_client.app.user_middleware)
# raise Exception()
session_provider, user = irrd_db_session_with_user
self._login(test_client, user)
self._verify_mfa(test_client)
create_permission(session_provider, user)

app.force_csrf_in_testing = True
with TestClient(app, cookies=test_client.cookies) as client_csrf:
app.force_csrf_in_testing = False
client_csrf.cookies = test_client.cookies
response = client_csrf.post(self.url, data={"data": SAMPLE_MNTNER})
assert response.status_code == 200
assert mock_change_submission_handler.mock_calls[1][0] == "().load_text_blob"
mock_handler_kwargs = mock_change_submission_handler.mock_calls[1][2]
assert mock_handler_kwargs["object_texts_blob"] == SAMPLE_MNTNER
assert mock_handler_kwargs["internal_authenticated_user"] is None

def test_valid_mntner_logged_in_mfa_incomplete_user_management(
self, irrd_db_session_with_user, test_client, mock_change_submission_handler
self,
irrd_db_session_with_user,
test_client,
mock_change_submission_handler,
tmp_gpg_dir,
):
session_provider, user = irrd_db_session_with_user
self._login(test_client, user)
Expand All @@ -192,7 +231,11 @@ def test_valid_mntner_logged_in_mfa_incomplete_user_management(
assert mock_handler_kwargs["internal_authenticated_user"] is None

def test_valid_mntner_not_logged_in(
self, irrd_db_session_with_user, test_client, mock_change_submission_handler
self,
irrd_db_session_with_user,
test_client,
mock_change_submission_handler,
tmp_gpg_dir,
):
session_provider, user = irrd_db_session_with_user
response = test_client.get(self.url)
Expand All @@ -213,7 +256,10 @@ class TestRpslUpdateWithInitial(WebRequestTest):
requires_mfa = False

def test_valid_mntner_logged_in_mfa_complete_no_user_management(
self, irrd_db_session_with_user, test_client
self,
irrd_db_session_with_user,
test_client,
tmp_gpg_dir,
):
session_provider, user = irrd_db_session_with_user
self._login(test_client, user)
Expand All @@ -227,7 +273,10 @@ def test_valid_mntner_logged_in_mfa_complete_no_user_management(
assert "DUMMYVALUE" in response.text.upper()

def test_valid_mntner_logged_in_mfa_complete_user_management(
self, irrd_db_session_with_user, test_client
self,
irrd_db_session_with_user,
test_client,
tmp_gpg_dir,
):
session_provider, user = irrd_db_session_with_user
self._login(test_client, user)
Expand All @@ -241,7 +290,10 @@ def test_valid_mntner_logged_in_mfa_complete_user_management(
assert "DUMMYVALUE" not in response.text.upper()

def test_valid_mntner_logged_in_mfa_incomplete_user_management(
self, irrd_db_session_with_user, test_client
self,
irrd_db_session_with_user,
test_client,
tmp_gpg_dir,
):
session_provider, user = irrd_db_session_with_user
self._login(test_client, user)
Expand All @@ -252,7 +304,7 @@ def test_valid_mntner_logged_in_mfa_incomplete_user_management(
assert "TEST-MNT" in response.text
assert "DUMMYVALUE" in response.text.upper()

def test_valid_mntner_not_logged_in(self, irrd_db_session_with_user, test_client):
def test_valid_mntner_not_logged_in(self, irrd_db_session_with_user, test_client, tmp_gpg_dir):
session_provider, user = irrd_db_session_with_user
response = test_client.get(self.url)
assert response.status_code == 200
Expand Down

0 comments on commit 52bea40

Please sign in to comment.